As part of a school project I'm implementing some cryptographic algorithms in Haskell. As you probably know, this involves quite a lot of low-level bit fiddling. Now I am stuck on one particular subroutine which is causing me a headache. The routine, which is a permutation on 256 bits, works as follows:
Input: a 256 bit block.
Then all the even-numbered bits (0, 2, ...) in the input block become the first 128 bits of the output block, while the odd-numbered bits become the last 128 bits. More specifically, the formula for the i'th bit of the output is given as follows (a_i is the i'th bit of the input block, and b is the output):
b_i = a_{2i}
b_{i + 2^(d-1)} = a_{2i+1}
for i from 0 to 2^(d-1) - 1, with d = 8.
As a toy example, assume we used a reduced version of the routine which worked with 16 bit blocks instead of 256 bits. Then the following bitstring would be permuted as follows:
1010 1010 1010 1010 -> 1111 1111 0000 0000
I have not been able to come up with a clean implementation of this function. In particular, I have been trying a ByteString -> ByteString signature, but that more or less forces me to work at Word8 granularity. And each byte in the output ByteString is a function of bits spread across several bytes of the input, which requires some really messy operations.
I will be really grateful for any kind of hint or advice on how to approach this problem.
If you want an efficient implementation, I don't think you can avoid working with bytes. Here is an example solution. It assumes that there is always an even number of bytes in the ByteString. I'm not very familiar with unboxing or strictness tweaking, but I think these would be necessary if you want to be very efficient.
import Data.ByteString (pack, unpack, ByteString)
import Data.Bits
import Data.Word
-- the main attraction
packString :: ByteString -> ByteString
packString = pack . packWords . unpack
-- main attraction equivalent, in [Word8]
packWords :: [Word8] -> [Word8]
packWords ws = evenPacked ++ unevenPacked
where evenBits = map packEven ws
unevenBits = map packUneven ws
evenPacked = consumePairs packNibbles evenBits
unevenPacked = consumePairs packNibbles unevenBits
-- combines 2 low nibbles (the low 4 bits of each argument) into a (high nibble, low nibble) word
-- assumes that only the low nibble of both arguments can be non-zero.
packNibbles :: Word8 -> Word8 -> Word8
packNibbles w1 w2 = (shiftL w1 4) .|. w2
packEven w = packBits w [0, 2, 4, 6]
packUneven w = packBits w [1, 3, 5, 7]
-- packBits 254 [0, 2, 4, 6] = 14
-- packBits 254 [1, 3, 5, 7] = 15
packBits :: Word8 -> [Int] -> Word8
packBits w is = foldr (.|.) 0 $ map (packBit w) is
-- packBit 255 0 = 1
-- packBit 255 1 = 1
-- packBit 255 2 = 2
-- packBit 255 3 = 2
-- packBit 255 4 = 4
-- packBit 255 5 = 4
-- packBit 255 6 = 8
-- packBit 255 7 = 8
packBit :: Word8 -> Int -> Word8
packBit w i = shiftR (w .&. 2^i) ((i `div` 2) + (i `mod` 2))
-- sort of like map, but halves the list in size by consuming two elements.
-- Is there a clearer way to write this with built-in functions?
consumePairs :: (a -> a -> b) -> [a] -> [b]
consumePairs f (x : x' : xs) = f x x' : consumePairs f xs
consumePairs _ [] = []
consumePairs _ _ = error "list must contain even number of elements"
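As for the question in the comment above consumePairs: if the split package is acceptable, one alternative is to build it from Data.List.Split.chunksOf. A sketch (consumePairs' is my own name for the variant):
import Data.List.Split (chunksOf)

-- Same behaviour as consumePairs: combine consecutive pairs, error on odd length.
consumePairs' :: (a -> a -> b) -> [a] -> [b]
consumePairs' f = map apply . chunksOf 2
  where
    apply [x, y] = f x y
    apply _      = error "list must contain even number of elements"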
This should work:
import Data.List
import Data.Function
map fst $ sortBy (compare `on` snd) $ zip yourList $ cycle [0,1]
A bit of explanation:
As sortBy is stable (it preserves the original relative order of equal elements), we can pair each value at an even position with a 0 and each value at an odd position with a 1, then simply sort on the second component of the pair. All values at even positions will then come before the values at odd positions, but their relative order will be kept.
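For instance, wrapped up as a function on the bit list (deinterleave is my own name; it uses the two imports above):
deinterleave :: [a] -> [a]
deinterleave xs = map fst $ sortBy (compare `on` snd) $ zip xs (cycle [0, 1 :: Int])
For example, deinterleave [1..8] gives [1,3,5,7,2,4,6,8]; applied to your 256 bits it performs exactly the permutation described in the question.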
Unless performance is critical, I'd recommend using a bit vector representation for a project like this. As you've discovered, randomly accessing individual bits is something of a pain when they're in packed form, but Data.Vector provides a wealth of functions for these sorts of tasks.
import Data.Bits
import qualified Data.Vector as V
type BitVector = V.Vector Bool
unpack :: (Bits a) => a -> BitVector
unpack w = V.generate (bitSize w) (testBit w)
pack :: (Bits a) => BitVector -> a
pack v = V.ifoldl' set 0 v
where
set w i True = w `setBit` i
set w _ _ = w
mkPermutationVector :: Int -> V.Vector Int
mkPermutationVector d = V.generate (2^d) b
where
b i | i < 2^(d-1) = 2*i
| otherwise = let i' = i-2^(d-1)
in 2*i'+1
permute :: Int -> BitVector -> BitVector
permute d v = V.backpermute v (mkPermutationVector d)
Notice how this lets you specify the permutation by closely transcribing the mathematical description. This substantially reduces the likelihood of errors, and is more pleasant to write than bit-twiddly code.
To test with your example vector (in base 10):
*Main> import Data.Word
*Main Data.Word> let permute16 = pack . permute 4 . unpack :: Word16 -> Word16
*Main Data.Word> permute16 43690
65280
Now, by moving to bit vectors as your representation, you lose a lot of what you get for free by using Haskell types, such as Num instances. However, you can always implement the Num operations for your representation; here's a start:
plus :: BitVector -> BitVector -> BitVector
plus as bs = V.tail sums
where
(sums, carries) = V.unzip sumsAndCarries
sumsAndCarries = V.scanl' fullAdd (False, False) (V.zip as bs)
fullAdd (_, cin) (a, b) = ((a /= b) /= cin
, (a && b) || (cin && (a /= b)))
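For example, pack (plus (unpack (3 :: Word8)) (unpack (5 :: Word8))) :: Word8 evaluates to 8; the carry out of the top bit is simply discarded.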
You may also find Levent Erkok's sbv package useful, although I'm not sure it exposes a function as convenient as backpermute for your particular question.
Update: I thought this was a fun question to answer, so I went ahead and fleshed the code out a bit as a library: bit-vector.
I'm trying to write a function that, for a given n, produces an n*n matrix with unique rows and columns (a Latin square).
I have a function that gives me a list of strings "1", "2", ..., "n":
numSymbol :: Int -> [String]
I tried generating all permutations of this, then all n-length tuples of those permutations, and then checking each candidate for uniqueness in its rows and columns. But that (n!)^2 complexity works fine for 2 and 3, while for n > 3 it takes forever. Is it possible to build a Latin square from permutations directly? For example, from
permutation ( numSymbol 3) = [["1","2","3"],["1","3","2"],["2","1","3"],["2","3","1"],["3","1","2"],["3","2","1"]]
get
[[["1","2","3",],["2","1","3"],["3","1","2"]] , ....]
without generating lists like [["1",...],["1",...],...], when we know the first elements already disqualify them?
Note: since we can easily take a Latin square that's been filled with numbers from 1 to n and re-label it with anything we want, we can write code that uses integer symbols without giving anything away, so let's stick with that.
Anyway, the stateful backtracking/nondeterministic monad:
type StateList s = StateT s []
is helpful for this sort of problem.
Here's the idea. We know that every symbol s is going to appear exactly once in each row r, so we can represent this with an urn of all possible ordered pairs (r,s):
my_rs_urn = [(r,s) | r <- [1..n], s <- [1..n]]
Similarly, as every symbol s appears exactly once in each column c, we can use a second urn:
my_cs_urn = [(c,s) | c <- [1..n], s <- [1..n]]
Creating a Latin square is a matter of filling in each position (r,c) with a symbol s by removing the matching balls (r,s) and (c,s) (i.e., removing two balls, one from each urn) so that every ball is used exactly once. Our state will be the contents of the urns.
We need backtracking because we might reach a point where for a particular position (r,c), there is no s such that (r,s) and (c,s) are both still available in their respective urns. Also, a pleasant side-effect of list-based backtracking/nondeterminism is that it'll generate all possible Latin squares, not just the first one it finds.
Given this, our state will look like:
type Urn = [(Int,Int)]
data S = S
{ size :: Int
, rs :: Urn
, cs :: Urn }
I've included the size in the state for convenience. It won't ever be modified, so it actually ought to be in a Reader instead, but this is simpler.
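For what it's worth, a quick sketch of that alternative shape (Urns, M', and runM' are my own names; the rest of this answer sticks with the simpler State-only version):
import Control.Monad.Reader (ReaderT, runReaderT)
import Control.Monad.State (StateT, evalStateT)

-- Only the urns are state; the read-only size lives in a ReaderT layer.
data Urns = Urns { rsUrn :: Urn, csUrn :: Urn }

type M' = ReaderT Int (StateT Urns [])

runM' :: Int -> M' a -> [a]
runM' n act = evalStateT (runReaderT act n) (Urns p p)
  where p = [(r, s) | r <- [1 .. n], s <- [1 .. n]]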
We'll represent a square by a list of cell contents in row-major order (i.e., the symbols in positions [(1,1),(1,2),...,(1,n),(2,1),...,(n,n)]):
data Square = Square
Int -- square size
[Int] -- symbols in row-major order
deriving (Show)
Now, the monadic action to generate Latin squares will look like this:
type M = StateT S []
latin :: M Square
latin = do
n <- gets size
-- for each position (r,c), get a valid symbol `s`
cells <- forM (pairs n) (\(r,c) -> getS r c)
return $ Square n cells
pairs :: Int -> [(Int,Int)]
pairs n = -- same as [(x,y) | x <- [1..n], y <- [1..n]]
(,) <$> [1..n] <*> [1..n]
The worker function getS picks an s so that (r,s) and (c,s) are available in the respective urns, removing those pairs from the urns as a side effect. Note that getS is written non-deterministically, so it'll try every possible way of picking an s and associated balls from the urns:
getS :: Int -> Int -> M Int
getS r c = do
-- try each possible `s` in the row
s <- pickSFromRow r
-- can we put `s` in this column?
pickCS c s
-- if so, `s` is good
return s
Most of the work is done by the helpers pickSFromRow and pickCS. The first, pickSFromRow picks an s from the given row:
pickSFromRow :: Int -> M Int
pickSFromRow r = do
balls <- gets rs
-- "lift" here non-determinstically picks balls
((r',s), rest) <- lift $ choices balls
-- only consider balls in matching row
guard $ r == r'
-- remove the ball
modify (\st -> st { rs = rest })
-- return the candidate "s"
return s
It uses a choices helper which generates every possible way of pulling one element out of a list:
choices :: [a] -> [(a,[a])]
choices = init . (zipWith f <$> inits <*> tails)
where f a (x:b) = (x, a++b)
f _ _ = error "choices: internal error"
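For example, in GHCi:
λ> choices [1,2,3]
[(1,[2,3]),(2,[1,3]),(3,[1,2])]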
The second, pickCS checks if (c,s) is available in the cs urn, and removes it if it is:
pickCS :: Int -> Int -> M ()
pickCS c s = do
balls <- gets cs
-- only continue if the required ball is available
guard $ (c,s) `elem` balls
-- remove the ball
modify (\st -> st { cs = delete (c,s) balls })
With an appropriate driver for our monad:
runM :: Int -> M a -> [a]
runM n act = evalStateT act (S n p p)
where p = pairs n
this can generate all 12 Latin squares of size 3:
λ> runM 3 latin
[Square 3 [1,2,3,2,3,1,3,1,2],Square 3 [1,2,3,3,1,2,2,3,1],...]
or the 576 Latin squares of size 4:
λ> length $ runM 4 latin
576
Compiled with -O2, it's fast enough to enumerate all 161280 squares of size 5 in a couple seconds:
main :: IO ()
main = print $ length $ runM 5 latin
The list-based urn representation above isn't very efficient. On the other hand, because the lengths of the lists are pretty small, there's not that much to be gained by finding more efficient representations.
Nonetheless, here's complete code that uses efficient Map/Set representations tailored to the way the rs and cs urns are used. Compiled with -O2, it runs in constant space. For n=6, it can process about 100000 Latin squares per second, but that still means it'll need to run for a few hours to enumerate all 800 million of them.
{-# OPTIONS_GHC -Wall #-}
module LatinAll where
import Control.Monad.State
import Data.List
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map, (!))
import qualified Data.Map as Map
data S = S
{ size :: Int
, rs :: Map Int [Int]
, cs :: Set (Int, Int) }
data Square = Square
Int -- square size
[Int] -- symbols in row-major order
deriving (Show)
type M = StateT S []
-- Get Latin squares
latin :: M Square
latin = do
n <- gets size
cells <- forM (pairs n) (\(r,c) -> getS r c)
return $ Square n cells
-- All locations in row-major order [(1,1),(1,2)..(n,n)]
pairs :: Int -> [(Int,Int)]
pairs n = (,) <$> [1..n] <*> [1..n]
-- Get a valid `s` for position `(r,c)`.
getS :: Int -> Int -> M Int
getS r c = do
s <- pickSFromRow r
pickCS c s
return s
-- Get an available `s` in row `r` from the `rs` urn.
pickSFromRow :: Int -> M Int
pickSFromRow r = do
urn <- gets rs
(s, rest) <- lift $ choices (urn ! r)
modify (\st -> st { rs = Map.insert r rest urn })
return s
-- Remove `(c,s)` from the `cs` urn.
pickCS :: Int -> Int -> M ()
pickCS c s = do
balls <- gets cs
guard $ (c,s) `Set.member` balls
modify (\st -> st { cs = Set.delete (c,s) balls })
-- Return all ways of removing one element from list.
choices :: [a] -> [(a,[a])]
choices = init . (zipWith f <$> inits <*> tails)
where f a (x:b) = (x, a++b)
f _ _ = error "choices: internal error"
-- Run an action in the M monad.
runM :: Int -> M a -> [a]
runM n act = evalStateT act (S n rs0 cs0)
where rs0 = Map.fromAscList $ zip [1..n] (repeat [1..n])
cs0 = Set.fromAscList $ pairs n
main :: IO ()
main = do
print $ runM 3 latin
print $ length (runM 4 latin)
print $ length (runM 5 latin)
Somewhat remarkably, modifying the program to produce only reduced Latin squares (i.e., with symbols [1..n] in order in both the first row and the first column) requires changing only two functions:
-- All locations in row-major order, skipping first row and column
-- i.e., [(2,2),(2,3)..(n,n)]
pairs :: Int -> [(Int,Int)]
pairs n = (,) <$> [2..n] <*> [2..n]
-- Run an action in the M monad.
runM :: Int -> M a -> [a]
runM n act = evalStateT act (S n rs0 cs0)
where -- skip balls [(1,1)..(n,n)] for first row
rs0 = Map.fromAscList $ map (\r -> (r, skip r)) [2..n]
-- skip balls [(1,1)..(n,n)] for first column
cs0 = Set.fromAscList $ [(c,s) | c <- [2..n], s <- skip c]
skip i = [1..(i-1)]++[(i+1)..n]
With these modifications, the resulting Square will include symbols in row-major order but skipping the first row and column. For example:
λ> runM 3 latin
[Square 3 [3,1,1,2]]
means:
1 2 3 fill in question marks 1 2 3
2 ? ? =====================> 2 3 1
3 ? ? in row-major order 3 1 2
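If you ever want the full square back, a small post-processing step can reinsert the first row and column. A sketch (expandReduced and chunks are my names):
-- Rebuild the full n*n square from the (n-1)*(n-1) cells produced by the
-- modified pairs/runM above (row-major, skipping row 1 and column 1).
expandReduced :: Square -> Square
expandReduced (Square n cells) = Square n ([1 .. n] ++ concat rows)
  where
    rows = zipWith (:) [2 .. n] (chunks cells)
    chunks [] = []
    chunks xs = let (a, b) = splitAt (n - 1) xs in a : chunks b
For example, expandReduced (Square 3 [3,1,1,2]) is Square 3 [1,2,3,2,3,1,3,1,2], matching the picture above.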
This is fast enough to enumerate all 16,942,080 reduced Latin squares of size 7 in a few minutes:
$ stack ghc -- -O2 -main-is LatinReduced LatinReduced.hs && time ./LatinReduced
[1 of 1] Compiling LatinReduced ( LatinReduced.hs, LatinReduced.o )
Linking LatinReduced ...
16942080
real 3m9.342s
user 3m8.494s
sys 0m0.848s
I have written this function that computes Collatz sequences, and I see wildly varying execution times depending on the spin I give it. Apparently it is related to something called "memoization", but I have a hard time understanding what it is and how it works, and, unfortunately, the relevant article on HaskellWiki, as well as the papers it links to, have all proven not to be easily surmountable. They discuss intricate details of the relative performance of highly layman-indifferentiable tree constructions, while what I'm missing must be some very basic, very trivial point that these sources neglect to mention.
This is the code. It is a complete program, ready to be built and executed.
module Main where
import Data.Function
import Data.List (maximumBy)
size :: (Integral a) => a
size = 10 ^ 6
-- Nail the basics.
collatz :: Integral a => a -> a
collatz n | even n = n `div` 2
| otherwise = n * 3 + 1
recollatz :: Integral a => a -> a
recollatz = fix $ \f x -> if (x /= 1)
then f (collatz x)
else x
-- Now, I want to do the counting with a tuple monad.
mocollatz :: Integral b => b -> ([b], b)
mocollatz n = ([n], collatz n)
remocollatz :: Integral a => a -> ([a], a)
remocollatz = fix $ \f x -> if x /= 1
then f =<< mocollatz x
else return x
-- Trivialities.
collatzLength :: Integral a => a -> Int
collatzLength x = (length . fst $ (remocollatz x)) + 1
collatzPairs :: Integral a => a -> [(a, Int)]
collatzPairs n = zip [1..n] (collatzLength <$> [1..n])
longestCollatz :: Integral a => a -> (a, Int)
longestCollatz n = maximumBy order $ collatzPairs n
where
order :: Ord b => (a, b) -> (a, b) -> Ordering
order x y = snd x `compare` snd y
main :: IO ()
main = print $ longestCollatz size
With ghc -O2 it takes about 17 seconds, and without ghc -O2 about 22 seconds, to deliver the length and the seed of the longest Collatz sequence starting at any point below size.
Now, if I make these changes:
diff --git a/Main.hs b/Main.hs
index c78ad95..9607fe0 100644
--- a/Main.hs
+++ b/Main.hs
@@ -1,6 +1,7 @@
module Main where
import Data.Function
+import qualified Data.Map.Lazy as M
import Data.List (maximumBy)
size :: (Integral a) => a
@@ -22,10 +23,15 @@ recollatz = fix $ \f x -> if (x /= 1)
mocollatz :: Integral b => b -> ([b], b)
mocollatz n = ([n], collatz n)
-remocollatz :: Integral a => a -> ([a], a)
-remocollatz = fix $ \f x -> if x /= 1
- then f =<< mocollatz x
- else return x
+remocollatz :: (Num a, Integral b) => b -> ([b], a)
+remocollatz 1 = return 1
+remocollatz x = case M.lookup x (table mutate) of
+ Nothing -> mutate x
+ Just y -> y
+ where mutate x = remocollatz =<< mocollatz x
+
+table :: (Ord a, Integral a) => (a -> b) -> M.Map a b
+table f = M.fromList [ (x, f x) | x <- [1..size] ]
-- Trivialities.
Then it will take just about 4 seconds with ghc -O2, but I would not live long enough to see it complete without ghc -O2.
Looking at the details of cost centres with ghc -prof -fprof-auto -O2 reveals that the first version enters collatz about a hundred million times, while the patched one enters it only about one and a half million times. This must be the reason for the speedup, but I have a hard time understanding the inner workings of this magic. My best idea is that we replace a portion of expensive recursive calls with O(log n) map lookups, but I don't know if that's true, nor why it depends so much on some godforsaken compiler flags, while, as I see it, such performance swings should all follow solely from the language.
Can I haz an explanation of what happens here, and why the performance differs so vastly between ghc -O2 and plain ghc builds?
P.S. There are two requirements for achieving automagical memoization highlighted elsewhere on Stack Overflow:
Make a function to be memoized a top-level name.
Make a function to be memoized a monomorphic one.
In line with these requirements, I rebuilt remocollatz as follows:
remocollatz :: Int -> ([Int], Int)
remocollatz 1 = return 1
remocollatz x = mutate x
mutate :: Int -> ([Int], Int)
mutate x = remocollatz =<< mocollatz x
Now it's as top level and as monomorphic as it gets. Running time is about 11 seconds, versus the similarly monomorphized table version:
remocollatz :: Int -> ([Int], Int)
remocollatz 1 = return 1
remocollatz x = case M.lookup x (table mutate) of
Nothing -> mutate x
Just y -> y
mutate :: Int -> ([Int], Int)
mutate = \x -> remocollatz =<< mocollatz x
table :: (Int -> ([Int], Int)) -> M.Map Int ([Int], Int)
table f = M.fromList [ (x, f x) | x <- [1..size] ]
Running in less than 4 seconds.
I wonder why the memoization ghc is supposedly performing in the first case here is almost 3 times slower than my dumb table.
Can I haz an explanation of what happens here, and why the performance differs so vastly between ghc -O2 and plain ghc builds?
Disclaimer: this is a guess, not verified by viewing GHC core output. A careful answer would do so to verify the conjectures outlined below. You can try peering through it yourself: add -ddump-simpl to your compilation line and you will get copious output detailing exactly what GHC has done to your code.
You write:
remocollatz x = {- ... -} table mutate {- ... -}
where mutate x = remocollatz =<< mocollatz x
The expression table mutate in fact does not depend on x; but it appears on the right-hand side of an equation that takes x as an argument. Consequently, without optimizations, this table is recomputed each time remocollatz is called (presumably even from inside the computation of table mutate).
With optimizations, GHC notices that table mutate does not depend on x, and floats it to its own definition, effectively producing:
fresh_variable_name = table mutate
where mutate x = remocollatz =<< mocollatz x
remocollatz x = case M.lookup x fresh_variable_name of
{- ... -}
The table is therefore computed just once for the entire program run.
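If you would rather not depend on that optimization, you can do the float-out by hand and make the table a top-level constant. A sketch against the question's code (memoTable is my name; mocollatz and size are assumed to be in scope as defined in the question, with Data.Map.Lazy imported qualified as M):
-- Top-level memo table: one shared value for the whole program run.
memoTable :: M.Map Int ([Int], Int)
memoTable = M.fromList [ (x, mutate x) | x <- [1 .. size] ]

mutate :: Int -> ([Int], Int)
mutate x = remocollatz =<< mocollatz x

remocollatz :: Int -> ([Int], Int)
remocollatz 1 = return 1
remocollatz x = case M.lookup x memoTable of
  Nothing -> mutate x
  Just y  -> y
Being a top-level constant applicative form, memoTable is then built at most once per run, with or without -O2.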
don't know why it [the performance] depends so much on some godforsaken compiler flags, while, as I see it, such performance swings should all follow solely from the language.
Sorry, but Haskell doesn't work that way. The language definition tells clearly what the meaning of a given Haskell term is, but does not say anything about the runtime or memory performance needed to compute that meaning.
Another approach to memoization that works in some situations, like this one, is to use a boxed vector, whose elements are computed lazily. The function used to initialize each element can use other elements of the vector in its calculation. As long as the evaluation of an element of the vector doesn't loop and refer to itself, just the elements it recursively depends on will be evaluated. Once evaluated, an element is effectively memoized, and this has the further benefit that elements of the vector that are never referenced are never evaluated.
The Collatz sequence is a nearly ideal application for this technique, but there is one complication. The next Collatz value(s) in sequence from a value under the limit may be outside the limit, which would cause a range error when indexing the vector. I solved this by just iterating through the sequence until back under the limit and counting the steps to do so.
The following program takes 0.77 seconds to run unoptimized and 0.30 when optimized:
import qualified Data.Vector as V
limit = 10 ^ 6 :: Int
-- The Collatz function, which given a value returns the next in the sequence.
nextCollatz val
| odd val = 3 * val + 1
| otherwise = val `div` 2
-- Given a value, return the next Collatz value in the sequence that is less
-- than the limit and the number of steps to get there. For example, the
-- sequence starting at 13 is: [13, 40, 20, 10, 5, 16, 8, 4, 2, 1], so if
-- limit is 100, then (nextCollatzWithinLimit 13) is (40, 1), but if limit is
-- 15, then (nextCollatzWithinLimit 13) is (10, 3).
nextCollatzWithinLimit val = (firstInRange, stepsToFirstInRange)
where
firstInRange = head rest
stepsToFirstInRange = 1 + (length biggerThanLimit)
(biggerThanLimit, rest) = span (>= limit) (tail collatzSeqStartingWithVal)
collatzSeqStartingWithVal = iterate nextCollatz val
-- A boxed vector holding Collatz length for each index. The collatzFn used
-- to generate the value for each element refers back to other elements of
-- this vector, but since the vector elements are only evaluated as needed and
-- there aren't any loops in the Collatz sequences, the values are calculated
-- only as needed.
collatzVec :: V.Vector Int
collatzVec = V.generate limit collatzFn
where
collatzFn :: Int -> Int
collatzFn index
| index <= 1 = 1
| otherwise = (collatzVec V.! nextWithinLimit) + stepsToGetThere
where
(nextWithinLimit, stepsToGetThere) = nextCollatzWithinLimit index
main :: IO ()
main = do
-- Use a fold through the vector to find the longest Collatz sequence under
-- the limit, and keep track of both the maximum length and the initial
-- value of the sequence, which is the index.
  let (maxLength, maxIndex) = V.ifoldl' accMaxLen (0, 0) collatzVec
      accMaxLen acc@(accMaxLen, accMaxIndex) index currLen
        | currLen <= accMaxLen = acc
        | otherwise = (currLen, index)
  putStrLn $ "Max Collatz length below " ++ show limit ++ " is "
    ++ show maxLength ++ " at index " ++ show maxIndex
I'm aware that with Data.Bits I can do bitwise manipulation on anything with an instance of Bits. That however is not what I want. I explicitly want to manipulate a list of Bools which represent the bit value at each bit index.
More precisely I want to manipulate the bits of a 1-byte integer.
I have managed to get this to work using map/testBit and foldl'/setBit. Which you can see from my tests in ghci below:
λ import Data.Bits
λ import Data.List
λ let convertToBits :: Int -> [Bool] ; convertToBits x = map (testBit x) [7,6..0]
λ let convertFromBits :: [Bool] -> Int ;
convertFromBits bs = foldl' (\acc (i,b) -> if b then setBit acc i else acc) 0 $ zip [7,6..0] bs
λ convertToBits 255
[True,True,True,True,True,True,True,True]
λ convertToBits 2
[False,False,False,False,False,False,True,False]
λ convertFromBits [False,False,False,True,True,False,False,True]
25
λ convertFromBits $ convertToBits 255
255
λ convertFromBits $ convertToBits 10
10
However, this doesn't feel like a particularly great implementation. I'm wondering if this (or something similar) might be in the standard library?
EDIT ------
A quick final note: even though I have answered the question below, I thought I'd add a bit more in case anyone stumbles on this and knows a better answer. I'll be storing all of the bits in a UArray, which will bit-pack the Bool values.
In thinking about it, what would be really good is a function that converts a Word8 into a UArray ix Bool. I'd like to think there is some efficient / simple way to do this but I don't know enough about the low level stuff to work it out.
This was previously an edit in my question. I would still like to know if there is anything better. I'm imagining something that literally coerces a Word8 into an array of its constituent bits but perhaps that's a bit lofty...
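For what it's worth, here is a minimal sketch of the Word8-to-array direction using the standard array package (word8ToBitArray is my own name; it is not a literal coercion, but UArray Int Bool is bit-packed internally):
import Data.Array.Unboxed (UArray, listArray)
import Data.Bits (testBit)
import Data.Word (Word8)

-- Index 0 holds bit 0, the least significant bit of the byte.
word8ToBitArray :: Word8 -> UArray Int Bool
word8ToBitArray w = listArray (0, 7) (map (testBit w) [0 .. 7])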
The package bitwise does a similar implementation to my own, but uses bit shifting for its fromList function. In lieu of a better idea, this suggests that using Data.Bits combined with some sort of map and some sort of fold is the canonical solution:
-- | Convert a little-endian list of bits to 'Bits'.
{-# INLINE fromListLE #-}
fromListLE :: (Num b, Bits b) => [Bool] {- ^ \[least significant bit, ..., most significant bit\] -} -> b
fromListLE = foldr f 0
where
f b i = fromBool b .|. (i `shiftL` 1)
-- | Convert a 'Bits' (with a defined 'bitSize') to a list of bits, in
-- little-endian order.
{-# INLINE toListLE #-}
toListLE :: (Num b, Bits b) => b -> [Bool] {- ^ \[least significant bit, ..., most significant bit\] -}
toListLE b = P.map (testBit b) [0 .. bitSize b - 1]
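(fromBool above is the package's Bool-to-0-or-1 conversion. As a quick sanity check, fromListLE [True,False,False,True,True] evaluates to 25, since the head of the list is bit 0.)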
This isn't thoroughly tested, but it's a proposal for a native implementation.
toBitsBySize :: Int -> Int -> [Bool]
toBitsBySize 0 x = []
toBitsBySize sz 0 = [False | i <- [1..sz]]
toBitsBySize sz x = if k == 0
then False : (toBitsBySize n x)
else True : (toBitsBySize n (x - k*m))
where n = sz - 1
m = 2^n
k = x `div` m
toBits8 = toBitsBySize 8
Some examples:
*Main> toBits8 2
[False,False,False,False,False,False,True,False]
*Main> toBits8 255
[True,True,True,True,True,True,True,True]
*Main> toBits8 129
[True,False,False,False,False,False,False,True]
I worked my way through Mike Vanier's monad tutorial (which is excellent) and I'm working on a few of the exercises in his post on how to use a "State" monad.
In particular, he suggests an exercise which consists of writing functions for factorial and fibonacci using a State monad. I gave it a shot and came up with the answers below. (I find do notation pretty confusing, hence my choice of syntax).
Neither of my implementations looks particularly "Haskell-y" and, in the interest of not internalizing bad practices, I thought I'd ask folks for input on how they would have gone about implementing these functions (using the State monad). Is it possible to write this code far more simply (aside from switching to do notation)? I strongly suspect this is the case.
I'm aware that it's a bit impractical to use a state monad for this purpose but this is purely a learning exercise - pun most certainly intended.
That said, the performance is not that much worse: to compute the 100000th Fibonacci number (the answer is ~21k digits long), the unfoldr version took ~1.2 sec (in GHCi) vs. ~1.5 sec for the State monad version.
import Control.Monad.State (State, get, put, evalState)
import Data.List (unfoldr)
fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1,0,1,n)
fib_state :: State (Integer,Integer,Integer,Integer) Integer
fib_state = get >>=
\s ->
let (p1,p2,ctr,n) = s
in case compare ctr n of
LT -> put (p1+p2, p1, ctr+1, n) >> fib_state
_ -> return p1
factorial :: Integer -> Integer
factorial n = evalState fact_state (n,1)
fact_state :: State (Integer,Integer) Integer
fact_state = get >>=
\s ->
let (n,f) = s
in case n of
0 -> return f
_ -> put (n-1,f*n) >> fact_state
-------------------------------------------------------------------
--Functions below are used only to test output of functions above
factorial' :: Integer -> Integer
factorial' n = product [1..n]
fibonacci' :: Int -> Integer
fibonacci' 0 = 1
fibonacci' 1 = 1
fibonacci' n =
let getFst (a,b,c) = a
in getFst
$ last
$ unfoldr (\(p1,p2,cnt) ->
if cnt == n
then Nothing
else Just ((p1,p2,cnt)
,(p1+p2,p1,cnt+1))
) (1,1,1)
Your functions seem to be a bit more complicated than they need to be, but you have the right idea. For the factorial, all you need to keep track of is the current number you're multiplying by and the number that you've accumulated so far. So, we'll say that State Int Int is a computation that operates on the current number on the state and returns the number that you've multiplied up until now:
fact_state :: State Int Int
fact_state = get >>= \x -> if x <= 1
then return 1
else (put (x - 1) >> fmap (*x) fact_state)
factorial :: Int -> Int
factorial = evalState fact_state
Prelude Control.Monad.State.Strict Control.Applicative> factorial <$> [1..10]
[1,2,6,24,120,720,5040,40320,362880,3628800]
The fibonacci sequence is similar. You need to keep the last two numbers in order to know what you're going to be adding together, and how far you've gone so far:
fibs_state :: State (Int, Int, Int) Int
fibs_state = get >>= \(x1, x2, n) -> if n == 0
then return x1
else (put (x2, x1+x2, n-1) >> fibs_state)
fibonacci n = evalState fibs_state (0, 1, n)
Prelude Control.Monad.State.Strict Control.Applicative> fibonacci <$> [1..10]
[1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
Two stylistic suggestions:
\s ->
let (p1,p2,ctr,n) = s
in ...
is equivalent to:
\(p1,p2,ctr,n) -> ...
and your case statement for fib_state may be written with an if statement:
if ctr < n
then put (p1+p2, p1, ctr+1, n) >> fib_state
else return p1
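Putting both suggestions together, fib_state could read like this (same behaviour as the original, just restyled):
fib_state :: State (Integer,Integer,Integer,Integer) Integer
fib_state = get >>= \(p1, p2, ctr, n) ->
  if ctr < n
    then put (p1+p2, p1, ctr+1, n) >> fib_state
    else return p1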
Here is my first Haskell program. What parts would you write in a better way?
-- Multiplication table
-- Returns n*n multiplication table in base b
import Text.Printf
import Data.List
import Data.Char
-- Returns n*n multiplication table in base b
mulTable :: Int -> Int -> String
mulTable n b = foldl (++) (verticalHeader n b w) (map (line n b w) [0..n])
where
lo = 2* (logBase (fromIntegral b) (fromIntegral n))
w = 1+fromInteger (floor lo)
verticalHeader :: Int -> Int -> Int -> String
verticalHeader n b w = (foldl (++) tableHeader columnHeaders)
++ "\n"
++ minusSignLine
++ "\n"
where
tableHeader = replicate (w+2) ' '
columnHeaders = map (horizontalHeader b w) [0..n]
minusSignLine = concat ( replicate ((w+1)* (n+2)) "-" )
horizontalHeader :: Int -> Int -> Int -> String
horizontalHeader b w i = format i b w
line :: Int -> Int -> Int -> Int -> String
line n b w y = (foldl (++) ((format y b w) ++ "|" )
(map (element b w y) [0..n])) ++ "\n"
element :: Int -> Int -> Int -> Int -> String
element b w y x = format (y * x) b w
toBase :: Int -> Int -> [Int]
toBase b v = toBase' [] v where
toBase' a 0 = a
toBase' a v = toBase' (r:a) q where (q,r) = v `divMod` b
toAlphaDigits :: [Int] -> String
toAlphaDigits = map convert where
convert n | n < 10 = chr (n + ord '0')
| otherwise = chr (n + ord 'a' - 10)
format :: Int -> Int -> Int -> String
format v b w = concat spaces ++ digits ++ " "
where
digits = if v == 0
then "0"
else toAlphaDigits ( toBase b v )
l = length digits
spaceCount = if (l > w) then 0 else (w-l)
spaces = replicate spaceCount " "
Here are some suggestions:
To make the tabularity of the computation more obvious, I would pass the list [0..n] to the line function rather than passing n.
I would further split out the computation of the horizontal and vertical axes so that they are passed as arguments to mulTable rather than computed there.
Haskell is higher-order, and almost none of the computation has to do with multiplication. So I would change the name of mulTable to binopTable and pass the actual multiplication in as a parameter.
Finally, the formatting of individual numbers is repetitious. Why not pass \x -> format x b w as a parameter, eliminating the need for b and w?
The net effect of the changes I am suggesting is that you build a general higher-order function for creating tables for binary operators. Its type becomes something like
binopTable :: (i -> String) -> (i -> i -> i) -> [i] -> [i] -> String
and you wind up with a much more reusable function—for example, Boolean truth tables should be a piece of cake.
Higher-order and reusable is the Haskell Way.
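As a rough illustration of that shape, here is an untested sketch matching the type above (row and column headers are omitted for brevity; the cell formatter is passed in as fmt):
-- One line of output per row; each cell is op applied to the row and column values.
binopTable :: (i -> String) -> (i -> i -> i) -> [i] -> [i] -> String
binopTable fmt op rows cols =
  unlines [ concatMap (\c -> fmt (op r c)) cols | r <- rows ]
-- e.g. a Boolean truth table for (&&):
-- putStr $ binopTable (\b -> if b then " T" else " F") (&&) [False,True] [False,True]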
You don't use anything from import Text.Printf.
Stylistically, you use more parentheses than necessary. Haskellers tend to find code more readable when it's cleaned of extraneous stuff like that. Instead of something like h x = f (g x), write h = f . g.
Nothing here really requires Int; (Integral a) => a ought to do.
foldl (++) x xs is the same as concat (x : xs); I trust the built-in concat to work better than your implementation.
Also, you should prefer foldr when the combining function is lazy in its second argument, as (++) is: because Haskell is lazy, this reduces stack space (and it also works on infinite lists).
Also, unwords and unlines are shortcuts for intercalate " " and concat . map (++ "\n") respectively, i.e. "join with spaces" and "join with newlines (plus trailing newline)"; you can replace a couple things by those.
Unless you use big numbers, w = length $ takeWhile (<= n) $ iterate (* b) 1 is probably faster. Or, in the case of a lazy programmer, let w = length $ toBase b n.
concat (replicate ((w+1) * (n+2)) "-") == replicate ((w+1) * (n+2)) '-'; not sure how you missed this one, you got it right just a couple of lines up.
You do the same thing with concat spaces, too. However, wouldn't it be easier to actually use the Text.Printf import and write printf "%*s " w digits?
Norman Ramsey gave excellent high-level (design) suggestions; Below are some low-level ones:
First, consult with HLint. HLint is a friendly program that gives you rudimentary advice on how to improve your Haskell code!
In your case HLint gives 7 suggestions. (mostly about redundant brackets)
Modify your code according to HLint's suggestions until it likes what you feed it.
More HLint-like stuff:
concat (replicate i "-"). Why not replicate i '-'?
Consult with Hoogle whenever there is reason to believe that a function you need is already available in Haskell's libraries. Haskell comes with tons of useful functions so Hoogle should come in handy quite often.
Need to concatenate strings? Search for [String] -> String, and voila you found concat. Now go replace all those folds.
The previous search also suggested unlines. Actually, this even better suits your needs. It's magic!
Optional: pause and thank Neil M in your heart for making Hoogle and HLint, and thank others for making other good stuff like Haskell, bridges, tennis balls, and sanitation.
Now, for every function that takes several arguments of the same type, make it clear which means what, by giving them descriptive names. This is better than comments, but you can still use both.
So
-- Returns n*n multiplication table in base b
mulTable :: Int -> Int -> String
mulTable n b =
becomes
mulTable :: Int -> Int -> String
mulTable size base =
To soften the blow of the extra characters from the previous suggestion: when a function is only used once, and is not very useful by itself, put it inside its caller's scope, in its where clause, where it can use the caller's variables, saving you the need to pass everything to it.
So
line :: Int -> Int -> Int -> Int -> String
line n b w y =
concat
$ format y b w
: "|"
: map (element b w y) [0 .. n]
element :: Int -> Int -> Int -> Int -> String
element b w y x = format (y * x) b w
becomes
line :: Int -> Int -> Int -> Int -> String
line n b w y =
concat
$ format y b w
: "|"
: map element [0 .. n]
where
element x = format (y * x) b w
You can even move line into mulTable's where clause; imho, you should.
If you find a where clause nested inside another where clause troubling, then I suggest changing your indentation habits. My recommendation is to use consistent indentation of always 2 or always 4 spaces. Then you can easily see, everywhere, where the where in the other where is at.
Below's what it looks like (with a few other changes in style):
import Data.List
import Data.Char
mulTable :: Int -> Int -> String
mulTable size base =
unlines $
[ vertHeaders
, minusSignsLine
] ++ map line [0 .. size]
where
vertHeaders =
concat
$ replicate (cellWidth + 2) ' '
: map horizontalHeader [0 .. size]
horizontalHeader i = format i base cellWidth
minusSignsLine = replicate ((cellWidth + 1) * (size + 2)) '-'
cellWidth = length $ toBase base (size * size)
line y =
concat
$ format y base cellWidth
: "|"
: map element [0 .. size]
where
element x = format (y * x) base cellWidth
toBase :: Integral i => i -> i -> [i]
toBase base
= reverse
. map (`mod` base)
. takeWhile (> 0)
. iterate (`div` base)
toAlphaDigit :: Int -> Char
toAlphaDigit n
| n < 10 = chr (n + ord '0')
| otherwise = chr (n + ord 'a' - 10)
format :: Int -> Int -> Int -> String
format v b w =
spaces ++ digits ++ " "
where
digits
| v == 0 = "0"
| otherwise = map toAlphaDigit (toBase b v)
spaces = replicate (w - length digits) ' '
0) Add a main function :-), at least a rudimentary one:
import System.Environment (getArgs)
import Control.Monad (liftM)
main :: IO ()
main = do
args <- liftM (map read) $ getArgs
case args of
(n:b:_) -> putStrLn $ mulTable n b
_ -> putStrLn "usage: nntable n base"
1) Run ghc or runhaskell with -Wall; run the code through HLint.
While HLint doesn't suggest anything special here (only some redundant brackets), GHC will tell you that you don't actually need Text.Printf here...
2) Try running it with base = 1, base = 0, or base = -1.
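(Hint: toBase never terminates for base = 1, since v `divMod` 1 never shrinks v, and base <= 0 is no better. One simple way out, as a sketch with mulTableChecked as my own name, is to reject degenerate bases up front:)
-- Reject bases for which positional notation is undefined.
mulTableChecked :: Int -> Int -> String
mulTableChecked n b
  | b < 2     = error "mulTable: base must be at least 2"
  | otherwise = mulTable n b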
If you want multiline comments use:
{- A multiline
comment -}
Also, never use foldl; use foldl' instead in cases where you are dealing with large lists which must be folded. It is more memory efficient.
A brief comment saying what each function does, its arguments and return value, is always good. I had to read the code pretty carefully to fully make sense of it.
Some would say if you do that, explicit type signatures may not be required. That's an aesthetic question, I don't have a strong opinion on it.
One minor caveat: if you do remove the type signatures, you'll automatically get the polymorphic Integral support ephemient mentioned, but you will still need one around toAlphaDigits because of the infamous "monomorphism restriction."