Summing a large list of numbers is too slow - haskell

Task: "Sum the first 15,000,000 even numbers."
Haskell:
nats = [1..] :: [Int]
evens = filter even nats :: [Int]
mySum :: Int
mySum = sum $ take 15000000 evens
...but mySum takes ages. More precisely, it is about 10-20 times slower than C/C++.
Many times I've found that a Haskell solution coded naturally is something like 10 times slower than C. I expected GHC to be a very neatly optimizing compiler, and a task such as this doesn't seem that tough.
So, one would expect something like 1.5-2x slower than C. Where is the problem?
Can this be solved better?
This is the C code I'm comparing it with:
long long sum = 0;
int n = 0, i = 1;
for (;;) {
    if (i % 2 == 0) {
        sum += i;
        n++;
    }
    if (n == 15000000)
        break;
    i++;
}
Edit 1: I know that it can be computed in O(1). Please, resist.
Edit 2: I know that the evens are [2,4..], but the function even could be some other O(1) predicate and needs to be implemented as a function.

Lists are not loops
So don't be surprised if you get slower code when you use lists as a loop replacement and the loop body is small.
nats = [1..] :: [Int]
evens = filter even nats :: [Int]
dumbSum :: Int
dumbSum = sum $ take 15000000 evens
sum is not a "good consumer", so GHC is not (yet) able to eliminate the intermediate lists completely.
If you compile with optimisations (and don't export nats), GHC is smart enough to fuse the filter with the enumeration,
Rec {
Main.main_go [Occ=LoopBreaker]
:: GHC.Prim.Int# -> GHC.Prim.Int# -> [GHC.Types.Int]
[GblId, Arity=1, Caf=NoCafRefs, Str=DmdType L]
Main.main_go =
\ (x_aV2 :: GHC.Prim.Int#) ->
let {
r_au7 :: GHC.Prim.Int# -> [GHC.Types.Int]
[LclId, Str=DmdType]
r_au7 =
case x_aV2 of wild_Xl {
__DEFAULT -> Main.main_go (GHC.Prim.+# wild_Xl 1);
9223372036854775807 -> n_r1RR
} } in
case GHC.Prim.remInt# x_aV2 2 of _ {
__DEFAULT -> r_au7;
0 ->
let {
wild_atm :: GHC.Types.Int
[LclId, Str=DmdType m]
wild_atm = GHC.Types.I# x_aV2 } in
let {
lvl_s1Rp :: [GHC.Types.Int]
[LclId]
lvl_s1Rp =
GHC.Types.:
@ GHC.Types.Int wild_atm (GHC.Types.[] @ GHC.Types.Int) } in
\ (m_aUL :: GHC.Prim.Int#) ->
case GHC.Prim.<=# m_aUL 1 of _ {
GHC.Types.False ->
GHC.Types.: @ GHC.Types.Int wild_atm (r_au7 (GHC.Prim.-# m_aUL 1));
GHC.Types.True -> lvl_s1Rp
}
}
end Rec }
but that's as far as GHC's fusion takes it. You are left with boxing Ints and constructing list cells. If you give it a loop, like you give it to the C compiler,
module Main where
import Data.Bits
main :: IO ()
main = print dumbSum
dumbSum :: Int
dumbSum = go 0 0 1
  where
    go :: Int -> Int -> Int -> Int
    go sm ct n
        | ct >= 15000000 = sm
        | n .&. 1 == 0   = go (sm + n) (ct+1) (n+1)
        | otherwise      = go sm ct (n+1)
you get the approximate relation of running times between the C and the Haskell version you expected.
This sort of algorithm is not what GHC has been taught to optimise well; there are bigger fish to fry elsewhere before the limited manpower is put into these optimisations.

The reason list fusion can't work here is actually rather subtle. Say we define the right RULE to fuse the list away:
import GHC.Base
sum2 :: Num a => [a] -> a
sum2 = sum
{-# NOINLINE [1] sum2 #-}
{-# RULES "sum" forall (f :: forall b. (a->b->b)->b->b).
                  sum2 (build f) = f (+) 0 #-}
(The short explanation is that we define sum2 as an alias of sum, which we forbid GHC to inline early, so the RULE has a chance to fire before sum2 gets eliminated. Then we look for sum2 directly next to the list-builder build (see definition) and replace it by direct arithmetic.)
This has mixed success, as it yields the following Core:
Main.$wgo =
\ (w_s1T4 :: GHC.Prim.Int#) ->
case GHC.Prim.remInt# w_s1T4 2 of _ {
__DEFAULT ->
case w_s1T4 of wild_Xg {
__DEFAULT -> Main.$wgo (GHC.Prim.+# wild_Xg 1);
15000000 -> 0
};
0 ->
case w_s1T4 of wild_Xg {
__DEFAULT ->
case Main.$wgo (GHC.Prim.+# wild_Xg 1) of ww_s1T7 { __DEFAULT ->
GHC.Prim.+# wild_Xg ww_s1T7
};
15000000 -> 15000000
}
}
Which is nice, completely fused code - with the sole problem that we have a call to $wgo in a non-tail-call position. This means that we aren't looking at a loop, but actually at a deeply recursive function, with predictable program results:
Stack space overflow: current size 8388608 bytes.
The root problem here is that the Prelude's list fusion can only fuse right folds, and computing the sum as a right fold directly causes the excessive stack consumption.
The obvious fix would be to use a fusion framework that can actually deal with left folds, such as Duncan's stream-fusion package, which actually implements sum fusion.
Another solution would be to hack around it - and implement the left fold using a right fold:
main = print $ foldr (\x c -> c . (+x)) id [2,4..15000000] 0
This actually produces close-to-perfect code with current versions of GHC. On the other hand, this is generally not a good idea as it relies on GHC being smart enough to eliminate the partially applied functions. Already adding a filter into the chain will break that particular optimization.
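To make the trick above concrete, here is a minimal sketch of a left fold written as a right fold, generalised to any list of Ints. The names sumViaFoldr and step are made up for this illustration; the only change from the one-liner above is a strict application ($!), so the accumulator is forced at every step. Whether GHC turns it into a tight loop still depends on the optimiser.

```haskell
-- Hypothetical helper, not from the original answer: a left-to-right sum
-- expressed through foldr. Each element contributes a function that
-- updates the accumulator and then hands it to the continuation k.
sumViaFoldr :: [Int] -> Int
sumViaFoldr xs = foldr step id xs 0
  where
    step :: Int -> (Int -> Int) -> Int -> Int
    step x k acc = k $! acc + x
```

For example, sumViaFoldr [2,4..100] evaluates to 2550.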

Sum first 15,000,000 even numbers:
{-# LANGUAGE BangPatterns #-}
g :: Integer -- 15000000*15000001 = 225000015000000
g = go 1 0 0
  where
    go i !a c | c == 15000000 = a
    go i !a c | even i = go (i+1) (a+i) (c+1)
    go i !a c = go (i+1) a c
ought to be the fastest.

If you want to be sure to traverse the list only once, you can write the traversal explicitly:
nats = [1..] :: [Int]
requiredOfX :: Int -> Bool -- this way you can write a different requirement
requiredOfX x = even x
dumbSum :: Int
dumbSum = dumbSum' 0 0 nats
  where dumbSum' acc 15000000 _ = acc
        dumbSum' acc count (x:xs)
          | requiredOfX x = dumbSum' (acc + x) (count + 1) xs
          | otherwise     = dumbSum' acc count xs -- only matching elements are counted

First, you can be clever as young Gauss was and compute the sum in O(1).
Fun stuff aside, your Haskell solution uses lists. I'm quite sure your C/C++ solution doesn't. (Haskell lists are very easy to use so one is tempted to use them even in cases where it might not be appropriate.) Try benchmarking this:
sumBy2 :: Integer -> Integer
sumBy2 = f 0
  where
    f result n | n <= 1    = result
               | otherwise = f (n + result) (n - 2)
Compile it using GHC with the -O2 flag. This function is tail-recursive, so the compiler can implement it very efficiently.
Update: If you want to use the even function, that's possible too:
sumBy2 :: Integer -> Integer
sumBy2 = f 0
  where
    f result n | n <= 0    = result
               | even n    = f (n + result) (n - 1)
               | otherwise = f result (n - 1)
You can also easily make the filtering function a parameter:
sumFilter :: (Integral a) => (a -> Bool) -> a -> a
sumFilter filtfn = f 0
  where
    f result n | n <= 0    = result
               | filtfn n  = f (n + result) (n - 1)
               | otherwise = f result (n - 1)

A strict left fold works much faster:
foldl' (+) 0 $ take 15000000 [2, 4..]
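For reference, a complete version of that one-liner might look as follows. foldl' lives in Data.List; the function name sumFirstMatching and its predicate parameter are assumptions added here so that even can be swapped for another test, as the question asks.

```haskell
import Data.List (foldl')

-- Sum the first n naturals that satisfy p, using a strict left fold so the
-- accumulator never builds up a chain of thunks.
sumFirstMatching :: (Int -> Bool) -> Int -> Int
sumFirstMatching p n = foldl' (+) 0 . take n . filter p $ [1 ..]
```

sumFirstMatching even 15000000 is the sum asked for in the question; on a small input, sumFirstMatching even 10 gives 110.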

Another thing to note is that nats and evens are so-called Constant Applicative Forms, or CAFs for short. Basically, those correspond to top-level definitions without any arguments. CAFs are a bit of an odd duck, for instance being the reason for the Dreaded Monomorphism Restriction; I'm not sure the language definition even allows CAFs to be inlined.
In my mental model of how Haskell executes, by the time dumbSum returns a value, evens will be evaluated to look something like 2:4: ... : 30000000 : <thunk> and nats to 1:2: ... : 30000000 : <thunk>, where the <thunk>s represent something that's not been looked at yet. If my understanding is correct, these allocations of : do have to happen and can't be optimized away.
So one way of speeding things up without altering your code too much would be to simply write:
dumbSum :: Int
dumbSum = sum . take 15000000 . filter even $ [1..]
or
dumbSum = sum $ take 15000000 evens where
  nats = [1..]
  evens = filter even nats
On my machine, compiled with -O2, that alone seems to result in a roughly 30% speedup.
I'm no GHC connoisseur (I've never even profiled a Haskell program!), so I could be wildly off the mark, though.


Alternative way to write a Haskell function of single input and output of this kind

Good morning everyone!
I'm using the following function as a fitting example of a function that needs to have a simple input and output. In this case it's a function that converts a number from decimal to binary form, as a list of digits no less, just because it is convenient later on.
I chose to write it like this because, even though a number goes in and a list comes out, another structure is needed as an intermediate step: it holds the digits found so far and the quotient of the division, as input for the next step of the loop. I will clean up the necessary mess before outputting anything, though, by selecting the part of the structure that I'm interested in (in this case the second one), and not counters or other stuff that I'm done with. (As I mentioned, this is an example only, and it's not unusual in other cases to initialize the until loop with a triplet like (a,b,c), only to pick one of them at the end, as I see fit, with the help of an additional function, like pickXof3.)
So there,
dec2Bin :: Int -> [Int]
dec2Bin num = snd $ until
    (\(n,l) -> n <= 0)                                    -- test
    (\(n,l) -> (fst $ division n, (snd $ division n):l))  -- main function
    (num,[])                                              -- initialization
  where division a = divMod a 2
I find it very convenient that Haskell, although lacking traditional for/while loops, has a function like until, which reminds me very much of Mathematica's NestWhile, which I'm familiar with.
In the past I would write something even uglier, like two functions, a "helper" one and a "main" one, like so
dec2BinHelper :: (Int,[Int]) -> (Int,[Int])
dec2BinHelper (n,l)
  | n <= 0    = (n,l)
  | otherwise = dec2BinHelper (fst $ division n, (snd $ division n):l)
  where division a = divMod a 2
-- a function with the sole purpose to act as a front-end to the helper function, initializing its call parameters and picking up its output
dec2Bin :: Int -> [Int]
dec2Bin n = snd $ dec2BinHelper (n,[])
which I think is unnecessarily bloated.
Still, while the use of until allows me to define just one function, I get the feeling that it could be done even simpler/easier to read, perhaps in a way more fitting to functional programming. Is that so? How would you write such a function differently, while keeping the input and output at the absolutely essential values?
I strongly prefer your second solution. I'd start a clean-up with two things: use pattern matching, and use where to hide your helper functions. So:
dec2Bin :: Int -> [Int]
dec2Bin n = snd $ dec2BinHelper (n, []) where
  dec2BinHelper (n, l)
    | n <= 0 = (n, l)
    | otherwise = dec2BinHelper (d, m:l)
    where (d, m) = divMod n 2
Now, in the base case, you return a tuple; but then immediately call snd on it. Why not fuse the two?
dec2Bin :: Int -> [Int]
dec2Bin n = dec2BinHelper (n, []) where
  dec2BinHelper (n, l)
    | n <= 0 = l
    | otherwise = dec2BinHelper (d, m:l)
    where (d, m) = divMod n 2
There's no obvious reason why you should pass these arguments in a tuple, rather than as separate arguments, which is more idiomatic and saves some allocation/deallocation noise besides.
dec2Bin :: Int -> [Int]
dec2Bin n = dec2BinHelper n [] where
  dec2BinHelper n l
    | n <= 0 = l
    | otherwise = dec2BinHelper d (m:l)
    where (d, m) = divMod n 2
You can swap the arguments to dec2BinHelper and eta-reduce; that way, you will not be shadowing the definition of n.
dec2Bin :: Int -> [Int]
dec2Bin = dec2BinHelper [] where
  dec2BinHelper l n
    | n <= 0 = l
    | otherwise = dec2BinHelper (m:l) d
    where (d, m) = divMod n 2
Since you know that n > 0 in the recursive call, you can use the slightly faster quotRem in place of divMod. You could also consider using bitwise operations like (.&. 1) and (`shiftR` 1); they may be even better, but you should benchmark to know for sure.
dec2Bin :: Int -> [Int]
dec2Bin = dec2BinHelper [] where
  dec2BinHelper l n
    | n <= 0 = l
    | otherwise = dec2BinHelper (r:l) q
    where (q, r) = quotRem n 2
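As a sketch of the bitwise alternative mentioned above (not benchmarked here; the name dec2BinBits is made up for this illustration), the remainder becomes a mask with 1 and the quotient becomes a right shift by one:

```haskell
import Data.Bits (shiftR, (.&.))

-- Hypothetical bitwise variant of dec2Bin: the lowest bit is n .&. 1 and
-- shifting right by one drops it. Like the versions above, it returns []
-- for n <= 0.
dec2BinBits :: Int -> [Int]
dec2BinBits = dec2BinHelper []
  where
    dec2BinHelper l n
      | n <= 0    = l
      | otherwise = dec2BinHelper (n .&. 1 : l) (n `shiftR` 1)
```

For example, dec2BinBits 10 gives [1,0,1,0].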
When you don't have a descriptive name for your helper function, it's traditional to name it go or loop.
dec2Bin :: Int -> [Int]
dec2Bin = go [] where
  go l n
    | n <= 0 = l
    | otherwise = go (r:l) q
    where (q, r) = quotRem n 2
At this point, the two sides of the conditional are short enough that I'd be tempted to put them on a single line, though this is something of an aesthetic choice.
dec2Bin :: Int -> [Int]
dec2Bin = go [] where
  go l n = if n <= 0 then l else go (r:l) q
    where (q, r) = quotRem n 2
Finally, a comment on the name: the input isn't really in decimal in any meaningful sense. (Indeed, it's much more physically accurate to think of the input as already being in binary!) Perhaps int2Bin or something like that would be more accurate. Or let the type speak for itself, and just call it toBin.
toBin :: Int -> [Int]
toBin = go [] where
  go l n = if n <= 0 then l else go (r:l) q
    where (q, r) = quotRem n 2
At this point I'd consider this code quite idiomatic.

How do I memoize?

I have written this function that computes Collatz sequences, and I see wildly varying times of execution depending on the spin I give it. Apparently it is related to something called "memoization", but I have a hard time understanding what it is and how it works, and, unfortunately, the relevant article on HaskellWiki, as well as the papers it links to, have all proven to not be easily surmountable. They discuss intricate details of the relative performance of highly layman-indifferentiable tree constructions, while what I miss must be some very basic, very trivial point that these sources neglect to mention.
This is the code. It is a complete program, ready to be built and executed.
module Main where
import Data.Function
import Data.List (maximumBy)
size :: (Integral a) => a
size = 10 ^ 6
-- Nail the basics.
collatz :: Integral a => a -> a
collatz n | even n = n `div` 2
          | otherwise = n * 3 + 1
recollatz :: Integral a => a -> a
recollatz = fix $ \f x -> if (x /= 1)
                          then f (collatz x)
                          else x
-- Now, I want to do the counting with a tuple monad.
mocollatz :: Integral b => b -> ([b], b)
mocollatz n = ([n], collatz n)
remocollatz :: Integral a => a -> ([a], a)
remocollatz = fix $ \f x -> if x /= 1
                            then f =<< mocollatz x
                            else return x
-- Trivialities.
collatzLength :: Integral a => a -> Int
collatzLength x = (length . fst $ (remocollatz x)) + 1
collatzPairs :: Integral a => a -> [(a, Int)]
collatzPairs n = zip [1..n] (collatzLength <$> [1..n])
longestCollatz :: Integral a => a -> (a, Int)
longestCollatz n = maximumBy order $ collatzPairs n
  where
    order :: Ord b => (a, b) -> (a, b) -> Ordering
    order x y = snd x `compare` snd y
main :: IO ()
main = print $ longestCollatz size
With ghc -O2 it takes about 17 seconds, without ghc -O2 -- about 22 seconds to deliver the length and the seed of the longest Collatz sequence starting at any point below size.
Now, if I make these changes:
diff --git a/Main.hs b/Main.hs
index c78ad95..9607fe0 100644
--- a/Main.hs
+++ b/Main.hs
@@ -1,6 +1,7 @@
module Main where
import Data.Function
+import qualified Data.Map.Lazy as M
import Data.List (maximumBy)
size :: (Integral a) => a
@@ -22,10 +23,15 @@ recollatz = fix $ \f x -> if (x /= 1)
mocollatz :: Integral b => b -> ([b], b)
mocollatz n = ([n], collatz n)
-remocollatz :: Integral a => a -> ([a], a)
-remocollatz = fix $ \f x -> if x /= 1
- then f =<< mocollatz x
- else return x
+remocollatz :: (Num a, Integral b) => b -> ([b], a)
+remocollatz 1 = return 1
+remocollatz x = case M.lookup x (table mutate) of
+ Nothing -> mutate x
+ Just y -> y
+ where mutate x = remocollatz =<< mocollatz x
+
+table :: (Ord a, Integral a) => (a -> b) -> M.Map a b
+table f = M.fromList [ (x, f x) | x <- [1..size] ]
-- Trivialities.
Then it will take just about 4 seconds with ghc -O2, but I would not live long enough to see it complete without ghc -O2.
Looking at the details of cost centres with ghc -prof -fprof-auto -O2 reveals that the first version enters collatz about a hundred million times, while the patched one enters it just about one and a half million times. This must be the reason for the speedup, but I have a hard time understanding the inner workings of this magic. My best idea is that we replace a portion of expensive recursive calls with O(log n) map lookups, but I don't know if it's true and why it depends so much on some godforsaken compiler flags, while, as I see it, such performance swings should all follow solely from the language.
Can I haz an explanation of what happens here, and why the performance differs so vastly between ghc -O2 and plain ghc builds?
P.S. Two requirements for achieving automagical memoization are highlighted elsewhere on Stack Overflow:
Make a function to be memoized a top-level name.
Make a function to be memoized a monomorphic one.
In line with these requirements, I rebuilt remocollatz as follows:
remocollatz :: Int -> ([Int], Int)
remocollatz 1 = return 1
remocollatz x = mutate x
mutate :: Int -> ([Int], Int)
mutate x = remocollatz =<< mocollatz x
Now it's as top level and as monomorphic as it gets. Running time is about 11 seconds, versus the similarly monomorphized table version:
remocollatz :: Int -> ([Int], Int)
remocollatz 1 = return 1
remocollatz x = case M.lookup x (table mutate) of
    Nothing -> mutate x
    Just y -> y
mutate :: Int -> ([Int], Int)
mutate = \x -> remocollatz =<< mocollatz x
table :: (Int -> ([Int], Int)) -> M.Map Int ([Int], Int)
table f = M.fromList [ (x, f x) | x <- [1..size] ]
-- Running in less than 4 seconds.
I wonder why the memoization ghc is supposedly performing in the first case here is almost 3 times slower than my dumb table.
"Can I haz an explanation of what happens here, and why the performance differs so vastly between ghc -O2 and plain ghc builds?"
Disclaimer: this is a guess, not verified by viewing GHC core output. A careful answer would do so to verify the conjectures outlined below. You can try peering through it yourself: add -ddump-simpl to your compilation line and you will get copious output detailing exactly what GHC has done to your code.
You write:
remocollatz x = {- ... -} table mutate {- ... -}
  where mutate x = remocollatz =<< mocollatz x
The expression table mutate in fact does not depend on x; but it appears on the right-hand side of an equation that takes x as an argument. Consequently, without optimizations, this table is recomputed each time remocollatz is called (presumably even from inside the computation of table mutate).
With optimizations, GHC notices that table mutate does not depend on x, and floats it to its own definition, effectively producing:
fresh_variable_name = table mutate
  where mutate x = remocollatz =<< mocollatz x
remocollatz x = case M.lookup x fresh_variable_name of
  {- ... -}
The table is therefore computed just once for the entire program run.
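The same sharing can be written down by hand instead of relying on the optimiser. The sketch below is a simplified, monomorphic stand-in for the question's code, not the asker's program: it stores sequence lengths rather than ([Int], Int) pairs, and the names limit, lengthTable, lengthFrom and collatzLength are invented for the illustration. Because lengthTable is a top-level value with no arguments, it is built at most once, with or without -O2.

```haskell
import qualified Data.Map.Lazy as M

-- Assumed table size for the illustration (the question uses 10 ^ 6).
limit :: Int
limit = 10000

collatz :: Int -> Int
collatz n
  | even n    = n `div` 2
  | otherwise = 3 * n + 1

-- Top-level table: the map is lazy in its values, so each entry is a thunk
-- that is evaluated on first lookup and then kept.
lengthTable :: M.Map Int Int
lengthTable = M.fromList [ (n, lengthFrom n) | n <- [1 .. limit] ]

-- Length of the sequence starting at n, without consulting the table for n itself.
lengthFrom :: Int -> Int
lengthFrom 1 = 1
lengthFrom n = 1 + collatzLength (collatz n)

-- Table lookup first; values outside the table are computed directly.
collatzLength :: Int -> Int
collatzLength n = case M.lookup n lengthTable of
  Just len -> len
  Nothing  -> lengthFrom n
```

For instance, collatzLength 13 is 10, matching the ten-element sequence 13, 40, 20, 10, 5, 16, 8, 4, 2, 1.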
"don't know why it [the performance] depends so much on some godforsaken compiler flags, while, as I see it, such performance swings should all follow solely from the language."
Sorry, but Haskell doesn't work that way. The language definition tells clearly what the meaning of a given Haskell term is, but does not say anything about the runtime or memory performance needed to compute that meaning.
Another approach to memoization that works in some situations, like this one, is to use a boxed vector, whose elements are computed lazily. The function used to initialize each element can use other elements of the vector in its calculation. As long as the evaluation of an element of the vector doesn't loop and refer to itself, just the elements it recursively depends on will be evaluated. Once evaluated, an element is effectively memoized, and this has the further benefit that elements of the vector that are never referenced are never evaluated.
The Collatz sequence is a nearly ideal application for this technique, but there is one complication. The next Collatz value(s) in sequence from a value under the limit may be outside the limit, which would cause a range error when indexing the vector. I solved this by just iterating through the sequence until back under the limit and counting the steps to do so.
The following program takes 0.77 seconds to run unoptimized and 0.30 when optimized:
import qualified Data.Vector as V
limit = 10 ^ 6 :: Int
-- The Collatz function, which given a value returns the next in the sequence.
nextCollatz val
  | odd val = 3 * val + 1
  | otherwise = val `div` 2
-- Given a value, return the next Collatz value in the sequence that is less
-- than the limit and the number of steps to get there. For example, the
-- sequence starting at 13 is: [13, 40, 20, 10, 5, 16, 8, 4, 2, 1], so if
-- limit is 100, then (nextCollatzWithinLimit 13) is (40, 1), but if limit is
-- 15, then (nextCollatzWithinLimit 13) is (10, 3).
nextCollatzWithinLimit val = (firstInRange, stepsToFirstInRange)
  where
    firstInRange = head rest
    stepsToFirstInRange = 1 + (length biggerThanLimit)
    (biggerThanLimit, rest) = span (>= limit) (tail collatzSeqStartingWithVal)
    collatzSeqStartingWithVal = iterate nextCollatz val
-- A boxed vector holding Collatz length for each index. The collatzFn used
-- to generate the value for each element refers back to other elements of
-- this vector, but since the vector elements are only evaluated as needed and
-- there aren't any loops in the Collatz sequences, the values are calculated
-- only as needed.
collatzVec :: V.Vector Int
collatzVec = V.generate limit collatzFn
  where
    collatzFn :: Int -> Int
    collatzFn index
      | index <= 1 = 1
      | otherwise = (collatzVec V.! nextWithinLimit) + stepsToGetThere
      where
        (nextWithinLimit, stepsToGetThere) = nextCollatzWithinLimit index
main :: IO ()
main = do
  -- Use a fold through the vector to find the longest Collatz sequence under
  -- the limit, and keep track of both the maximum length and the initial
  -- value of the sequence, which is the index.
  let (maxLength, maxIndex) = V.ifoldl' accMaxLen (0, 0) collatzVec
      accMaxLen acc@(accMaxLen, accMaxIndex) index currLen
        | currLen <= accMaxLen = acc
        | otherwise = (currLen, index)
  putStrLn $ "Max Collatz length below " ++ show limit ++ " is "
           ++ show maxLength ++ " at index " ++ show maxIndex

Implementing factorial and fibonacci using State monad (as a learning exercise)

I worked my way through Mike Vanier's monad tutorial (which is excellent) and I'm working on a few of the exercises in his post on how to use a "State" monad.
In particular, he suggests an exercise which consists of writing functions for factorial and fibonacci using a State monad. I gave it a shot and came up with the answers below. (I find do notation pretty confusing, hence my choice of syntax).
Neither of my implementations looks particularly "Haskell-y" and, in the interest of not internalizing bad practices, I thought I'd ask folks for input on how they would've gone about implementing these functions (using the state monad). Is it possible to write this code far more simply (aside from switching to do notation)? I strongly suspect this is the case.
I'm aware that it's a bit impractical to use a state monad for this purpose but this is purely a learning exercise - pun most certainly intended.
That said, the performance is not that much worse: in order to calc the factorial of 100000 (the answer is ~21k digits long), the unfoldr version took ~1.2 sec (in GHCi) vs. ~1.5 sec for the state monad version.
import Control.Monad.State (State, get, put, evalState)
import Data.List (unfoldr)
fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1,0,1,n)
fib_state :: State (Integer,Integer,Integer,Integer) Integer
fib_state = get >>=
  \s ->
    let (p1,p2,ctr,n) = s
    in case compare ctr n of
         LT -> put (p1+p2, p1, ctr+1, n) >> fib_state
         _  -> return p1
factorial :: Integer -> Integer
factorial n = evalState fact_state (n,1)
fact_state :: State (Integer,Integer) Integer
fact_state = get >>=
  \s ->
    let (n,f) = s
    in case n of
         0 -> return f
         _ -> put (n-1,f*n) >> fact_state
-------------------------------------------------------------------
--Functions below are used only to test output of functions above
factorial' :: Integer -> Integer
factorial' n = product [1..n]
fibonacci' :: Int -> Integer
fibonacci' 0 = 0 -- agrees with fibonacci 0 above
fibonacci' 1 = 1
fibonacci' n =
  let getFst (a,b,c) = a
  in getFst
     $ last
     $ unfoldr (\(p1,p2,cnt) ->
                  if cnt == n
                    then Nothing
                    else Just ((p1,p2,cnt)
                              ,(p1+p2,p1,cnt+1))
               ) (1,1,1)
Your functions seem to be a bit more complicated than they need to be, but you have the right idea. For the factorial, all you need to keep track of is the current number you're multiplying by and the number that you've accumulated so far. So, we'll say that State Int Int is a computation that operates on the current number on the state and returns the number that you've multiplied up until now:
fact_state :: State Int Int
fact_state = get >>= \x -> if x <= 1
                           then return 1
                           else (put (x - 1) >> fmap (*x) fact_state)
factorial :: Int -> Int
factorial = evalState fact_state
Prelude Control.Monad.State.Strict Control.Applicative> factorial <$> [1..10]
[1,2,6,24,120,720,5040,40320,362880,3628800]
The fibonacci sequence is similar. You need to keep the last two numbers in order to know what you're going to be adding together, and how far you've gone so far:
fibs_state :: State (Int, Int, Int) Int
fibs_state = get >>= \(x1, x2, n) -> if n == 0
                                     then return x1
                                     else (put (x2, x1+x2, n-1) >> fibs_state)
fibonacci n = evalState fibs_state (0, 1, n)
Prelude Control.Monad.State.Strict Control.Applicative> fibonacci <$> [1..10]
[1,1,2,3,5,8,13,21,34,55]
Two stylistic suggestions:
\s ->
  let (p1,p2,ctr,n) = s
  in ...
is equivalent to:
\(p1,p2,ctr,n) -> ...
and your case expression for fib_state may be written as an if expression:
if ctr < n
  then put (p1+p2, p1, ctr+1, n) >> fib_state
  else return p1
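Putting both suggestions together, the question's fib_state could be written like this. This is only a sketch of the combined result; it assumes the mtl package for Control.Monad.State, as in the question, and keeps the question's names.

```haskell
import Control.Monad.State (State, evalState, get, put)

-- The state is (current, previous, counter, target), as in the question.
fib_state :: State (Integer, Integer, Integer, Integer) Integer
fib_state = get >>= \(p1, p2, ctr, n) ->
  if ctr < n
    then put (p1 + p2, p1, ctr + 1, n) >> fib_state
    else return p1

fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1, 0, 1, n)
```

With these definitions, map fibonacci [1..10] gives [1,1,2,3,5,8,13,21,34,55].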

Euler #4 with bigger domain

Consider the modified Euler problem #4 -- "Find the maximum palindromic number which is a product of two numbers between 100 and 9999."
rev :: Int -> Int
rev x = rev' x 0
rev' :: Int -> Int -> Int
rev' n r
    | n == 0 = r
    | otherwise = rev' (n `div` 10) (r * 10 + n `mod` 10)
pali :: Int -> Bool
pali x = x == rev x
main :: IO ()
main = print . maximum $ [ x*y | x <- nums, y <- nums, pali (x*y)]
    where
      nums = [9999,9998..100]
This Haskell solution, compiled with -O2 and GHC 7.4.1, takes about 18 seconds. The similar C solution takes 0.1 second. So Haskell is 180 times slower. What's wrong with my solution? I assumed that Haskell solves this type of problem pretty well.
Appendix - analogue C solution:
#include <stdio.h>
#define A 100
#define B 9999
int ispali(int n)
{
    int n0=n, k=0;
    while (n>0) {
        k = 10*k + n%10;
        n /= 10;
    }
    return n0 == k;
}
int main(void)
{
    int max = 0;
    for (int i=B; i>=A; i--)
        for (int j=B; j>=A; j--) {
            if (i*j > max && ispali(i*j))
                max = i*j;
        }
    printf("%d\n", max);
}
"The similar C solution"
That is a common misconception.
Lists are not loops!
And using lists to emulate loops has performance implications unless the compiler is able to eliminate the list from the code.
If you want to compare apples to apples, write the Haskell structure more or less equivalent to a loop, a tail recursive worker (with strict accumulator, though often the compiler is smart enough to figure out the strictness by itself).
Now let's take a more detailed look. For comparison, the C, compiled with gcc -O3, takes ~0.08 seconds here, the original Haskell, compiled with ghc -O2 takes ~20.3 seconds, with ghc -O2 -fllvm ~19.9 seconds. Pretty terrible.
One mistake in the original code is to use div and mod. The C code uses the equivalent of quot and rem, which map to the machine division instructions and are faster than div and mod. For positive arguments, the semantics are the same, so whenever you know that the arguments are always non-negative, never use div and mod.
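The difference only shows up for negative operands: div and mod round towards negative infinity, while quot and rem truncate towards zero. A small illustration (the name divisionTable is made up here):

```haskell
-- Each row is (n `div` 2, n `mod` 2, n `quot` 2, n `rem` 2).
divisionTable :: [(Int, Int, Int, Int)]
divisionTable = [ (n `div` 2, n `mod` 2, n `quot` 2, n `rem` 2) | n <- [7, -7] ]
```

This evaluates to [(3,1,3,1),(-4,1,-3,-1)]: the two pairs agree for 7 and differ for -7.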
Changing that, the running time becomes ~15.4 seconds when compiling with the native code generator, and ~2.9 seconds when compiling with the LLVM backend.
The difference is due to the fact that even the machine division operations are quite slow, and LLVM replaces the division/remainder with a multiply-and-shift operation. Doing the same by hand for the native backend (actually, a slightly better replacement taking advantage of the fact that I know the arguments will always be non-negative) brings its time down to ~2.2 seconds.
We're getting closer, but are still a far cry from the C.
That is due to the lists. The code still builds a list of palindromes (and traverses a list of Ints for the two factors).
Since lists cannot contain unboxed elements, that means there is a lot of boxing and unboxing going on in the code, that takes time.
So let us eliminate the lists, and take a look at the result of translating the C to Haskell:
module Main (main) where
a :: Int
a = 100
b :: Int
b = 9999
ispali :: Int -> Bool
ispali n = go n 0
  where
    go 0 acc = acc == n
    go m acc = go (m `quot` 10) (acc * 10 + (m `rem` 10))
maxpal :: Int
maxpal = go 0 b
  where
    go mx i
        | i < a = mx
        | otherwise = go (inner mx b) (i-1)
          where
            inner m j
                | j < a = m
                | p > m && ispali p = inner p (j-1)
                | otherwise = inner m (j-1)
                  where
                    p = i*j
main :: IO ()
main = print maxpal
The nested loop is translated to two nested worker functions, and we use an accumulator to store the largest palindrome found so far. Compiled with ghc -O2, that runs in ~0.18 seconds, with ghc -O2 -fllvm it runs in ~0.14 seconds (yes, LLVM is better at optimising loops than the native code generator).
Still not quite there, but a factor of about 2 isn't too bad.
Some may find the following version, where the loop is abstracted out, more readable. The generated core is for all intents and purposes identical (modulo a switch of argument order), and the performance is of course the same:
module Main (main) where
a :: Int
a = 100
b :: Int
b = 9999
ispali :: Int -> Bool
ispali n = go n 0
  where
    go 0 acc = acc == n
    go m acc = go (m `quot` 10) (acc * 10 + (m `rem` 10))
downto :: Int -> Int -> a -> (a -> Int -> a) -> a
downto high low acc fun = go high acc
  where
    go i acc
        | i < low = acc
        | otherwise = go (i-1) (fun acc i)
maxpal :: Int
maxpal = downto b a 0 $ \m i ->
            downto b a m $ \mx j ->
                let p = i*j
                in if mx < p && ispali p then p else mx
main :: IO ()
main = print maxpal
@axblount is at least partly right; the following modification makes the program run almost three times as fast as the original:
maxPalindrome = foldl f 0
  where f a x | x > a && pali x = x
              | otherwise = a
main :: IO ()
main = print . maxPalindrome $ [x * y | x <- nums, y <- nums]
  where nums = [9999,9998..100]
That still leaves a factor 60 slowdown, though.
This is more true to what the C code is doing:
maxpali :: [Int] -> Int
maxpali xs = go xs 0
  where
    go [] m = m
    go (x:xs) m = if x > m && pali(x) then go xs x else go xs m
main :: IO()
main = print . maxpali $ [ x*y | x <- nums, y <- nums ]
  where nums = [9999,9998..100]
On my box this takes 2 seconds vs .5 for the C version.
Haskell may be storing that entire list [ x*y | x <- nums, y <- nums, pali (x*y)], whereas the C solution calculates the maximum on the fly. I'm not sure about this.
Also, the C solution will only calculate ispali if the product beats the previous maximum. I would bet Haskell calculates all palindrome products regardless of whether x*y is a possible max.
It seems to me that you are having a branch prediction problem. In the C code, you have two nested loops and as soon as a palindrome is seen in the inner loop, the rest of the inner loop will be skipped very fast.
With the way you feed in this list of products instead of the nested loops, I am not sure that ghc is doing any of this prediction.
Another way to write this is to use two folds, instead of one fold over the flattened list:
-- foldl g0 0 [x*y | x<-[b-1,b-2..a], y<-[b-1,b-2..a], pali(x*y)] (A)
-- foldl g1 0 [x*y | x<-[b-1,b-2..a], y<-[b-1,b-2..a]] (B)
-- foldl g2 0 [ [x*y | y<-[b-1,b-2..a]] | x<-[b-1,b-2..a]] (C)
maxpal b a = foldl f1 0 [b-1,b-2..a] -- (D)
  where
    f1 m x = foldl f2 m [b-1,b-2..a]
      where
        f2 m y | p>m && pali p = p
               | otherwise = m
          where p = x*y
main = print $ maxpal 10000 100
Seems to run much faster than (B) (as in larsmans's answer), too (only 3x - 4x slower than the following loops-based code). Fusing the foldl and enumFromThenTo definitions gets us the "functional loops" code (as in DanielFischer's answer),
maxpal_loops b a = f (b-1) 0 -- (E)
  where
    f x m | x < a = m
          | otherwise = g (b-1) m
      where
        g y m | y < a = f (x-1) m
              | p>m && pali p = g (y-1) p
              | otherwise = g (y-1) m
          where p = x*y
The (C) variant is very suggestive of further algorithmic improvements (that's outside the scope of the original Q of course) that exploit the hidden order in the lists, destroyed by the flattening:
{- foldl g2 0 [ [x*y | y<-[b-1,b-2..a]] | x<-[b-1,b-2..a]] (C)
foldl g2 0 [ [x*y | y<-[x, x-1..a]] | x<-[b-1,b-2..a]] (C1)
foldl g0 0 [ safehead 0 . filter pali $
[x*y | y<-[x, x-1..a]] | x<-[b-1,b-2..a]] (C2)
fst $ until ... (\(m,s)-> (max m .
safehead 0 . filter pali . takeWhile (> m) $
head s, tail s))
(0,[ [x*y | y<-[x, x-1..a]] | x<-[b-1,b-2..a]]) (C3)
safehead 0 $ filter pali $ mergeAllDescending
[ [x*y | y<-[x, x-1..a]] | x<-[b-1,b-2..a]] (C4)
-}
(C3) can stop as soon as the head x*y in a sub-list is smaller than the currently found maximum. It is what short-cutting functional loops code could achieve, but not (C4), which is guaranteed to find the maximal palindromic number first. Plus, for list-based code its algorithmic nature is more visually apparent, IMO.
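A minimal runnable sketch of the early-exit idea behind (C3). It assumes pali tests for decimal palindromes (its definition is not shown here), and the names maxPalEarly and firstPali are made up for illustration:

```haskell
import Data.Maybe (fromMaybe, listToMaybe)

-- Assumed definition: pali is used throughout but never shown.
pali :: Int -> Bool
pali n = s == reverse s
  where s = show n

-- Rows are descending in x, and each row is descending in y,
-- so the head of a row (x*x) is its largest product.
maxPalEarly :: Int -> Int -> Int
maxPalEarly b a = go 0 [ [x*y | y <- [x, x-1 .. a]] | x <- [b-1, b-2 .. a] ]
  where
    go m [] = m
    go m (r:rs) = case r of
      (p:_) | p > m -> go (max m (firstPali m r)) rs
      _             -> m   -- no later row can beat m, so stop here
    -- first (hence largest) palindrome in the row that still exceeds m, else 0
    firstPali m = fromMaybe 0 . listToMaybe . filter pali . takeWhile (> m)
```

Loaded in GHCi, maxPalEarly 1000 100 evaluates to 906609 (993 * 913). The takeWhile (> m) cuts each row short, and the case on the row head stops the outer traversal once x*x can no longer exceed the current maximum.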

What can be improved on my first haskell program?

Here is my first Haskell program. What parts would you write in a better way?
-- Multiplication table
-- Returns n*n multiplication table in base b
import Text.Printf
import Data.List
import Data.Char
-- Returns n*n multiplication table in base b
mulTable :: Int -> Int -> String
mulTable n b = foldl (++) (verticalHeader n b w) (map (line n b w) [0..n])
  where
    lo = 2* (logBase (fromIntegral b) (fromIntegral n))
    w = 1+fromInteger (floor lo)
verticalHeader :: Int -> Int -> Int -> String
verticalHeader n b w = (foldl (++) tableHeader columnHeaders)
    ++ "\n"
    ++ minusSignLine
    ++ "\n"
  where
    tableHeader = replicate (w+2) ' '
    columnHeaders = map (horizontalHeader b w) [0..n]
    minusSignLine = concat ( replicate ((w+1)* (n+2)) "-" )
horizontalHeader :: Int -> Int -> Int -> String
horizontalHeader b w i = format i b w
line :: Int -> Int -> Int -> Int -> String
line n b w y = (foldl (++) ((format y b w) ++ "|" )
    (map (element b w y) [0..n])) ++ "\n"
element :: Int -> Int -> Int -> Int -> String
element b w y x = format (y * x) b w
toBase :: Int -> Int -> [Int]
toBase b v = toBase' [] v where
    toBase' a 0 = a
    toBase' a v = toBase' (r:a) q where (q,r) = v `divMod` b
toAlphaDigits :: [Int] -> String
toAlphaDigits = map convert where
    convert n | n < 10 = chr (n + ord '0')
              | otherwise = chr (n + ord 'a' - 10)
format :: Int -> Int -> Int -> String
format v b w = concat spaces ++ digits ++ " "
  where
    digits = if v == 0
             then "0"
             else toAlphaDigits ( toBase b v )
    l = length digits
    spaceCount = if (l > w) then 0 else (w-l)
    spaces = replicate spaceCount " "
Here are some suggestions:
To make the tabularity of the computation more obvious, I would pass the list [0..n] to the line function rather than passing n.
I would further split out the computation of the horizontal and vertical axes so that they are passed as arguments to mulTable rather than computed there.
Haskell is higher-order, and almost none of the computation has to do with multiplication. So I would change the name of mulTable to binopTable and pass the actual multiplication in as a parameter.
Finally, the formatting of individual numbers is repetitious. Why not pass \x -> format x b w as a parameter, eliminating the need for b and w?
The net effect of the changes I am suggesting is that you build a general higher-order function for creating tables for binary operators. Its type becomes something like
binopTable :: (i -> String) -> (i -> i -> i) -> [i] -> [i] -> String
and you wind up with a much more reusable function—for example, Boolean truth tables should be a piece of cake.
Higher-order and reusable is the Haskell Way.
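One way the suggested binopTable could look, as a rough sketch only: the type is the one given above, but the layout (right-aligned cells, a blank corner cell, no separator line) is my own assumption rather than anything from the original program.

```haskell
-- Hypothetical sketch of the generalised table builder.
-- fmt renders a value, op is the binary operator, xs label the columns
-- and ys label the rows.
binopTable :: (i -> String) -> (i -> i -> i) -> [i] -> [i] -> String
binopTable fmt op xs ys = unlines (header : map row ys)
  where
    -- every string that will appear in the table, used to size the cells
    cells  = map fmt xs ++ map fmt ys ++ [fmt (op x y) | y <- ys, x <- xs]
    width  = maximum (0 : map length cells)
    pad s  = replicate (width - length s) ' ' ++ s
    header = unwords (pad "" : map (pad . fmt) xs)
    row y  = unwords (pad (fmt y) : [pad (fmt (op x y)) | x <- xs])
```

With this, putStr (binopTable show (*) [0..9] [0..9]) prints a decimal multiplication table, and putStr (binopTable show (&&) [False, True] [False, True]) prints a truth table for conjunction. Passing the formatter as an argument is what removes the b and w parameters from the table code.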
You don't use anything from import Text.Printf.
Stylistically, you use more parentheses than necessary. Haskellers tend to find code more readable when it's cleaned of extraneous stuff like that. Instead of something like h x = f (g x), write h = f . g.
Nothing here really requires Int; (Integral a) => a ought to do.
foldl (++) x xs == concat $ x : xs: I trust the built-in concat to work better than your implementation.
Also, you should prefer foldr when the function is lazy in its second argument, as (++) is – because Haskell is lazy, this reduces stack space (and also works on infinite lists).
Also, unwords and unlines are shortcuts for intercalate " " and concat . map (++ "\n") respectively, i.e. "join with spaces" and "join with newlines (plus trailing newline)"; you can replace a couple things by those.
Unless you use big numbers, w = length $ takeWhile (<= n) $ iterate (* b) 1 is probably faster. Or, in the case of a lazy programmer, let w = length $ toBase b n.
concat (replicate ((w+1) * (n+2)) "-") == replicate ((w+1) * (n+2)) '-' – not sure how you missed this one, you got it right just a couple lines up.
You do the same thing with concat spaces, too. However, wouldn't it be easier to actually use the Text.Printf import and write printf "%*s " w digits?
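The equivalences in these suggestions can be checked with a small sketch. The helper names (joinFold, joinConcat, ruleStrings, ruleChars, lazyPrefix, padCell) are invented here for illustration and do not appear in the original program:

```haskell
import Text.Printf (printf)

-- foldl (++) x xs and concat (x : xs) build the same string.
joinFold, joinConcat :: String -> [String] -> String
joinFold   x xs = foldl (++) x xs
joinConcat x xs = concat (x : xs)

-- Replicating a one-character string and concatenating is the same
-- as replicating the character directly.
ruleStrings, ruleChars :: Int -> String
ruleStrings k = concat (replicate k "-")
ruleChars   k = replicate k '-'

-- foldr with (++) is productive even on an infinite list.
lazyPrefix :: String
lazyPrefix = take 5 (foldr (++) [] (repeat "ab"))

-- Right-align digits in a field of width w, followed by one space,
-- using the "*" width specifier of Text.Printf.
padCell :: Int -> String -> String
padCell w digits = printf "%*s " w digits
```

In GHCi, padCell 4 "ff" gives "  ff " and lazyPrefix gives "ababa"; like the original format function, padCell adds no padding when the digits are already wider than w.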
Norman Ramsey gave excellent high-level (design) suggestions; below are some low-level ones:
First, consult with HLint. HLint is a friendly program that gives you rudimentary advice on how to improve your Haskell code!
In your case HLint gives 7 suggestions (mostly about redundant brackets).
Modify your code according to HLint's suggestions until it likes what you feed it.
More HLint-like stuff:
concat (replicate i "-"). Why not replicate i '-'?
Consult with Hoogle whenever there is reason to believe that a function you need is already available in Haskell's libraries. Haskell comes with tons of useful functions so Hoogle should come in handy quite often.
Need to concatenate strings? Search for [String] -> String, and voila you found concat. Now go replace all those folds.
The previous search also suggested unlines. Actually, this even better suits your needs. It's magic!
Optional: pause and, in your heart, thank Neil M for making Hoogle and HLint, and thank others for making other good stuff like Haskell, bridges, tennis balls, and sanitation.
Now, for every function that takes several arguments of the same type, make it clear which means what, by giving them descriptive names. This is better than comments, but you can still use both.
So
-- Returns n*n multiplication table in base b
mulTable :: Int -> Int -> String
mulTable n b =
becomes
mulTable :: Int -> Int -> String
mulTable size base =
To soften the extra characters blow of the previous suggestion: When a function is only used once, and is not very useful by itself, put it inside its caller's scope in its where clause, where it could use the callers' variables, saving you the need to pass everything to it.
So
line :: Int -> Int -> Int -> Int -> String
line n b w y =
  concat
  $ format y b w
  : "|"
  : map (element b w y) [0 .. n]
element :: Int -> Int -> Int -> Int -> String
element b w y x = format (y * x) b w
becomes
line :: Int -> Int -> Int -> Int -> String
line n b w y =
  concat
  $ format y b w
  : "|"
  : map element [0 .. n]
  where
    element x = format (y * x) b w
You can even move line into mulTable's where clause; imho, you should.
If you find a where clause nested inside another where clause troubling, then I suggest changing your indentation habits. My recommendation is to use consistent indentation of always 2 or always 4 spaces. Then you can easily see, everywhere, where the where in the other where is at.
Below's what it looks like (with a few other changes in style):
import Data.List
import Data.Char
mulTable :: Int -> Int -> String
mulTable size base =
  unlines $
  [ vertHeaders
  , minusSignsLine
  ] ++ map line [0 .. size]
  where
    vertHeaders =
      concat
      $ replicate (cellWidth + 2) ' '
      : map horizontalHeader [0 .. size]
    horizontalHeader i = format i base cellWidth
    minusSignsLine = replicate ((cellWidth + 1) * (size + 2)) '-'
    cellWidth = length $ toBase base (size * size)
    line y =
      concat
      $ format y base cellWidth
      : "|"
      : map element [0 .. size]
      where
        element x = format (y * x) base cellWidth
toBase :: Integral i => i -> i -> [i]
toBase base
  = reverse
  . map (`mod` base)
  . takeWhile (> 0)
  . iterate (`div` base)
toAlphaDigit :: Int -> Char
toAlphaDigit n
  | n < 10 = chr (n + ord '0')
  | otherwise = chr (n + ord 'a' - 10)
format :: Int -> Int -> Int -> String
format v b w =
  spaces ++ digits ++ " "
  where
    digits
      | v == 0 = "0"
      | otherwise = map toAlphaDigit (toBase b v)
    spaces = replicate (w - length digits) ' '
0) add a main function :-) at least rudimentary
import System.Environment (getArgs)
import Control.Monad (liftM)
main :: IO ()
main = do
  args <- liftM (map read) $ getArgs
  case args of
    (n:b:_) -> putStrLn $ mulTable n b
    _ -> putStrLn "usage: nntable n base"
1) run ghc or runhaskell with -Wall; run through hlint.
While hlint doesn't suggest anything special here (only some redundant brackets), ghc will tell you that you don't actually need Text.Printf here...
2) try running it with base = 1 or base = 0 or base = -1
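As far as I can tell, the toBase versions in this thread never terminate for base 1 (dividing by 1 never shrinks the value) and hit a divide-by-zero error for base 0. One possible guard is sketched below; toBaseSafe is a hypothetical name, and returning Maybe is just one of several reasonable designs:

```haskell
-- Hypothetical validated variant of toBase: rejects bases below 2 and
-- negative values instead of looping or crashing.
toBaseSafe :: Int -> Int -> Maybe [Int]
toBaseSafe base v
  | base < 2 || v < 0 = Nothing
  | otherwise         = Just (reverse (digitsLE v))
  where
    -- least-significant digit first; 0 yields no digits, as in toBase
    digitsLE 0 = []
    digitsLE n = n `mod` base : digitsLE (n `div` base)
```

For example, toBaseSafe 16 255 is Just [15,15], while toBaseSafe 1 5 and toBaseSafe 0 5 are both Nothing. The caller (main, in this case) can then print a usage message instead of hanging.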
If you want multiline comments use:
{- A multiline
comment -}
Also, prefer foldl' (from Data.List) over foldl when folding large lists with a strict operation. It forces the accumulator at each step instead of building a chain of thunks, so it is more memory efficient.
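A minimal sketch of the strict left fold on a numeric accumulator, where the difference matters most; sumStrict and countIf are illustrative names, not part of the program under review:

```haskell
import Data.List (foldl')

-- The accumulator is forced at every step, so no chain of (+) thunks
-- builds up while the list is consumed.
sumStrict :: [Int] -> Int
sumStrict = foldl' (+) 0

-- Same pattern with a counting step function.
countIf :: (a -> Bool) -> [a] -> Int
countIf p = foldl' (\n x -> if p x then n + 1 else n) 0
```

For instance, sumStrict [1 .. 1000000] is 500000500000. For joining strings with (++), concat or a foldr remains the better fit, as noted in the other answers.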
A brief comment saying what each function does, along with its arguments and return value, is always good. I had to read the code pretty carefully to fully make sense of it.
Some would say if you do that, explicit type signatures may not be required. That's an aesthetic question, I don't have a strong opinion on it.
One minor caveat: if you do remove the type signatures, you'll automatically get the polymorphic Integral support ephemient mentioned, but you will still need one around toAlphaDigits because of the infamous "monomorphism restriction."

Resources