Source

ScalaEuler / src / com / mklinke / euler / Problem41.scala

/**
 * *
 *  Copyright 2012 Martin Klinke, http://www.martinklinke.com.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.mklinke.euler

/**
 * Project Euler problem 41: find the largest n-digit pandigital prime, i.e. the
 * largest prime that uses each of the digits 1 to n exactly once.
 *
 * @author Martin Klinke
 *
 */
object Problem41 {

  /**
   * Prints, for every n from 9 down to 1 for which one exists, the largest
   * prime whose digits are a permutation of 1..n, followed by the elapsed time.
   */
  def main(args: Array[String]): Unit = {
    val start = System.currentTimeMillis()
    for (i <- (1 to 9).reverse) {
      val it = findAllPandigitalPrimes(i)
      if (it.hasNext)
        println("n = " + i + ", max pandigital prime: " + it.max)
    }
    println("Duration: " + (System.currentTimeMillis() - start) + " ms")
  }

  /**
   * Lazily enumerates every prime whose decimal digits are a permutation of 1..n.
   *
   * @param n number of digits; must lie in 1..9. Values above 9 would treat
   *          numbers like 10 as a single "digit", and values below 1 produce
   *          an empty digit string.
   * @return iterator over the primes, in the order produced by `permutations`
   * @throws IllegalArgumentException if `n` is outside 1..9
   */
  def findAllPandigitalPrimes(n: Int): Iterator[Long] = {
    require(n >= 1 && n <= 9, "n must be between 1 and 9, was " + n)
    val permutations = (1 to n).permutations
    findPandigitalPrimes(permutations).iterator
  }

  /**
   * Turns each digit sequence into the number it spells and keeps only the primes.
   *
   * The result is a lazy `Stream`, so `it` is only advanced as far as the stream
   * is actually forced.
   *
   * @param it digit sequences (for example permutations of 1..n)
   * @return the primes among the spelled numbers, in input order
   */
  def findPandigitalPrimes(it: Iterator[IndexedSeq[Int]]): Stream[Long] = {
    // The non-prime branch is a self call in tail position (scalac can turn it
    // into a loop); the prime branch defers recursion via the lazy tail of #::.
    def loop(it: Iterator[IndexedSeq[Int]]): Stream[Long] = {
      if (it.hasNext) {
        val current = it.next().mkString.toLong
        if (isPrime(current))
          current #:: loop(it)
        else
          loop(it)
      } else {
        Stream.empty
      }
    }
    loop(it)
  }

  /** Sum of the digits 1 to 9 (kept for callers; no longer used by `isPandigital`). */
  val expectedSum = (1 to 9).sum

  /** Product of the digits 1 to 9, i.e. 9! (kept for callers; no longer used by `isPandigital`). */
  val expectedProduct = (1 to 9).product

  /** The digits 1 through 9 in ascending order: "123456789". */
  private val sortedPandigitalDigits: String = (1 to 9).mkString

  /**
   * Tests whether `n` contains each of the digits 1 to 9 exactly once.
   *
   * Comparing only the digit sum and product is not sufficient:
   * "124445799" also has sum 45 and product 9!. Sorting the characters and
   * comparing with "123456789" checks the exact multiset of digits.
   *
   * @param n decimal string to test
   * @return true iff `n` is a permutation of "123456789"
   */
  def isPandigital(n: String): Boolean =
    n.sorted == sortedPandigitalDigits

  /**
   * Trial-division primality test.
   *
   * @param n number to test
   * @return true iff `n` is prime; 0, 1 and all negative numbers are not prime
   */
  def isPrime(n: Long): Boolean =
    n >= 2 && n == smallestDivisor(n)

  /**
   * Smallest divisor of `n` that is at least 2, or `n` itself when no such
   * divisor d with d * d <= n exists (so primes map to themselves).
   *
   * @param n number to factor; meaningful for n >= 2
   */
  def smallestDivisor(n: Long): Long = {
    findDivisor(n, 2)
  }

  /**
   * floor(sqrt(Long.MaxValue)): the largest divisor whose square still fits in a
   * Long. Any larger divisor's square exceeds every Long, so it can stop the search.
   */
  private final val maxSquarableDivisor = 3037000499L

  /**
   * Searches upward from `testDivisor` for the first value dividing `n`,
   * giving up (and returning `n`) once the candidate's square exceeds `n`.
   *
   * @param n number to factor
   * @param testDivisor first candidate divisor to try
   * @return the smallest divisor d >= testDivisor with d * d <= n, otherwise `n`
   */
  @scala.annotation.tailrec
  def findDivisor(n: Long, testDivisor: Long): Long = {
    // Check the overflow bound first: past it, testDivisor * testDivisor would
    // wrap around to a (possibly negative) value and never exceed n.
    if (testDivisor > maxSquarableDivisor || testDivisor * testDivisor > n)
      n
    else if (n % testDivisor == 0)
      testDivisor
    else
      findDivisor(n, testDivisor + 1)
  }
}