Fold that's both constant-space and short-circuiting - haskell

I'm trying to build a Haskell function that does basically the same thing as Prelude's product. Unlike that function, however, it should have these two properties:
It should operate in constant space (ignoring the fact that some numeric types like Integer aren't). For example, I want myProduct (replicate 100000000 1) to eventually return 1, unlike Prelude's product which uses up all of my RAM and then gives *** Exception: stack overflow.
It should short-circuit when it encounters a 0. For example, I want myProduct (0:undefined) to return 0, unlike Prelude's product which gives *** Exception: Prelude.undefined.
Here's what I've come up with so far:
myProduct :: (Eq n, Num n) => [n] -> n
myProduct = go 1
  where go acc (x:xs) = if x == 0 then 0 else acc `seq` go (acc * x) xs
        go acc [] = acc
That works exactly how I want it to for lists, but I'd like to generalize it to have type (Foldable t, Eq n, Num n) => t n -> n. Is it possible to do this with any of the folds? If I just use foldr, then it will short-circuit but won't be constant-space, and if I just use foldl', then it will be constant-space but won't short-circuit.

If you spell your function slightly differently, it's more obvious how to turn it into a foldr. Namely:
myProduct :: (Eq n, Num n) => [n] -> n
myProduct = flip go 1 where
  go (x:xs) = if x == 0 then \acc -> 0 else \acc -> acc `seq` go xs (acc * x)
  go [] = \acc -> acc
Now go has got that foldr flavor, and we can just fill in the holes.
myProduct :: (Foldable t, Eq n, Num n) => t n -> n
myProduct = flip go 1 where
  go = foldr
    (\x f -> if x == 0 then \acc -> 0 else \acc -> acc `seq` f (acc * x))
    (\acc -> acc)
Hopefully you can see where each of those pieces came from in the previous explicit-recursion style and how mechanical the transformation is. Then I'd make a few aesthetic tweaks:
myProduct :: (Foldable t, Eq n, Num n) => t n -> n
myProduct xs = foldr step id xs 1 where
  step 0 f acc = 0
  step x f acc = f $! acc * x
And we're all done! A bit of quick testing in ghci reveals that it still short-circuits on 0 as required and uses constant space when specialized to lists.
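The quick testing mentioned above can be sketched as follows. This is only an illustration: the name checks and the particular inputs are assumptions, not part of the original answer, and the constant-space claim itself is something to observe in GHCi rather than something this list can prove.

```haskell
-- Same definition as above, repeated so the sketch is self-contained.
myProduct :: (Foldable t, Eq n, Num n) => t n -> n
myProduct xs = foldr step id xs 1
  where
    step 0 _ _   = 0            -- short-circuit: the rest of the fold is never demanded
    step x f acc = f $! acc * x -- force the accumulator before continuing

-- Hypothetical checks; each entry should evaluate to True.
checks :: [Bool]
checks =
  [ myProduct [1 .. 5 :: Integer] == 120
  , myProduct (0 : undefined :: [Integer]) == 0        -- stops at the 0
  , myProduct (Just (7 :: Integer)) == 7               -- a non-list Foldable
  , myProduct (replicate 1000000 (1 :: Integer)) == 1  -- long input
  ]
```

Evaluating and checks in GHCi should give True if the function behaves as described.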

You might be looking for foldM. Instantiate it with m = Either b and you get short circuiting behavior (or Maybe, depends if you have many possible early exit values, or one known in advance).
foldM :: (Foldable t, Monad m) => (b -> a -> m b) -> b -> t a -> m b
I recall discussions whether there should be foldM', but IIRC GHC does the right thing most of the time.
import Control.Monad
import Data.Maybe
myProduct :: (Foldable t, Eq n, Num n) => t n -> n
myProduct = fromMaybe 0 . foldM go 1
  where go acc x = if x == 0 then Nothing else Just $! acc * x
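For the case with several possible early-exit values, a minimal sketch with m = Either might look like this. The Stop type and the name productOrStop are hypothetical and only serve to illustrate the idea; stopping on negative factors is an invented second exit condition.

```haskell
import Control.Monad (foldM)

-- Hypothetical reasons for leaving the fold early.
data Stop n = HitZero | HitNegative n
  deriving (Eq, Show)

-- Multiplies the elements, but exits with Left as soon as a zero
-- or a negative factor is seen; the rest of the input is not inspected.
productOrStop :: (Foldable t, Ord n, Num n) => t n -> Either (Stop n) n
productOrStop = foldM step 1
  where
    step acc x
      | x == 0    = Left HitZero
      | x < 0     = Left (HitNegative x)
      | otherwise = Right $! acc * x
```

With this sketch, productOrStop (2 : 0 : undefined) should give Left HitZero without touching the undefined tail, because Left short-circuits the monadic bind.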

Related

Is there any terminating fold in Haskell?

I need some kind of fold which can terminate if I already have the data I want.
For example I need to find first 3 numbers which are greater than 5. I decided to use Either for termination and my code looks like this:
terminatingFold :: ([b] -> a -> Either [b] [b]) -> [a] -> [b]
terminatingFold f l = reverse $ either id id $ fold [] l
  where fold acc [] = Right acc
        fold acc (x:xs) = f acc x >>= flip fold xs
first3NumsGreater5 acc x =
  if length acc >= 3
    then Left acc
    else Right (if x > 5 then (x : acc) else acc)
Are there some more clever/generic approaches?
The result of your function is a list, and it would be desirable if it were produced lazily, that is, extracting one item from the result should only require evaluating the input list up until the item is found there.
Unfolds are under-appreciated for these kinds of tasks. Instead of focusing on "consuming" the input list, let's think of it as a seed from which (paired with some internal accumulator) we can produce the result, element by element.
Let's define a Seed type that contains a generic accumulator paired with the as-yet unconsumed parts of the input:
{-# LANGUAGE NamedFieldPuns #-}
import Data.List (unfoldr)
data Seed acc input = Seed {acc :: acc, pending :: [input]}
Now let's reformulate first3NumsGreater5 as a function that either produces the next output element from the Seed, or signals that there aren't any more elements:
type Counter = Int
first3NumsGreater5 :: Seed Counter Int -> Maybe (Int, Seed Counter Int)
first3NumsGreater5 (Seed {acc, pending})
  | acc >= 3 =
      Nothing
  | otherwise =
      case dropWhile (<= 5) pending of
        [] -> Nothing
        x : xs -> Just (x, Seed {acc = succ acc, pending = xs})
Now our main function can be written in terms of unfoldr:
unfoldFromList ::
  (Seed acc input -> Maybe (output, Seed acc input)) ->
  acc ->
  [input] ->
  [output]
unfoldFromList next acc pending = unfoldr next (Seed {acc, pending})
Putting it to work:
main :: IO ()
main = print $ unfoldFromList first3NumsGreater5 0 [0, 6, 2, 7, 9, 10, 11]
-- [6,7,9]
Normally an early termination-capable fold is foldr with the combining function which is non-strict in its second argument. But, its information flow is right-to-left (if any), while you want it left-to-right.
A possible solution is to make foldr function as a left fold, which can then be made to stop early:
foldlWhile :: Foldable t
           => (a -> Bool) -> (r -> a -> r) -> r
           -> t a -> r
foldlWhile t f a xs = foldr cons (\acc -> acc) xs a
  where
    cons x r acc | t x = r (f acc x)
                 | otherwise = acc
You will need to tweak this for t to test the acc instead of x, to fit your purposes.
This function is foldlWhile from https://wiki.haskell.org/Foldl_as_foldr_alternative, re-written a little. foldl'Breaking from there might fit the bill a bit better.
foldr with the lazy reducer function can express corecursion perfectly fine just like unfoldr does.
And your code is already lazy: terminatingFold (\acc x -> Left acc) [1..] => []. That's why I'm not sure if this answer is "more clever", as you've requested.
edit: following a comment by @danidiaz, to make it properly lazy you'd have to code it as e.g.
first3above5 :: (Foldable t, Ord a, Num a)
             => t a -> [a]
first3above5 xs = foldr cons (const []) xs 0
  where
    cons x r i | x > 5 = if i==2 then [x]
                         else x : r (i+1)
               | otherwise = r i
This can be generalized further by abstracting the test and the count.
Of course it's just reimplementing take 3 . filter (> 5), but shows how to do it in general with foldr.
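One way the generalization could look, abstracting both the test and the count, is sketched here. The name firstNWhere and the guard for non-positive counts are assumptions added for illustration.

```haskell
-- Lazily yields the first n elements satisfying p, using only foldr.
-- The Int argument threaded through the fold counts matches found so far.
firstNWhere :: Foldable t => Int -> (a -> Bool) -> t a -> [a]
firstNWhere n p xs
  | n <= 0    = []
  | otherwise = foldr cons (const []) xs 0
  where
    cons x r i
      | p x       = if i == n - 1
                      then [x]             -- last wanted element: stop here
                      else x : r (i + 1)   -- emit x before looking further
      | otherwise = r i
```

As with take n . filter p, firstNWhere 3 (> 5) [1 ..] terminates on the infinite list, since nothing past the third match is demanded.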

Identity of the "accumulating parameter" of the foldr function

the foldr function:
foldr :: (a -> b -> b) -> b -> [a] -> b
foldr func acc [] = acc
foldr func acc (x:xs) = func x (foldr func acc xs)
catches patterns like those (left side)
and makes them simpler (right side)
sum :: [Integer] -> Integer       | sum :: [Integer] -> Integer
sum [] = 0                        | sum xs = foldr (+) 0 xs
sum (x:xs) = x + sum xs           |
                                  |
product :: [Integer] -> Integer   | product :: [Integer] -> Integer
product [] = 1                    | product xs = foldr (*) 1 xs
product (x:xs) = x * product xs   |
                                  |
concat :: [[a]] -> [a]            | concat :: [[a]] -> [a]
concat [] = []                    | concat xs = foldr (++) [] xs
concat (x:xs) = x ++ concat xs    |
----------------------------------------------------------------------
not using folds                   | using folds
one thing I noticed was that the acc argument, provided as input for the fold,
seems to be exactly the neutral element / identity element of that function.
In Mathematics the neutral element of the addition operation + is 0
because n + 0 = n, n ∈ ℝ
it doesn't change anything, in other words:
With this neutral element provided as an input for the addition function, the summand equals the sum.
(+) summand 0 = summand or summand + 0 = summand
The same goes for multiplication: the product of the factor and the identity equals the factor itself:
(*) factor 1 = factor
So is this just a coincidence, or is there something bigger behind it?
You're exactly right. We very often want to pass an "identity"-like element to foldr, so that the "starting point" doesn't affect the result at all. In fact, this is codified in Haskell with the Monoid typeclass. A monoid is an associative binary operation with an identity. The examples you provide are all examples of a monoid, and they all exist in Haskell.
+ on any Num is codified as a monoid over the Sum newtype.
* on any Num is codified as a monoid over the Product newtype.
++ on any list is codified as a monoid on [a].
And in fact we can go one step further. Folding over a monoid is such a common practice that we can do it automatically with fold (or foldMap, if you need to disambiguate). For instance,
import Data.Foldable
import Data.Monoid
sum :: Num a => [a] -> a
sum = getSum . foldMap Sum
product :: Num a => [a] -> a
product = getProduct . foldMap Product
concat :: [[a]] -> [a]
concat = fold
If you look in the source for Foldable, you can see that fold and foldMap are actually defined in terms of foldr on a monoid, so this is doing the exact same thing you just described.
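A minimal sketch of that relationship, assuming the hypothetical name foldMapViaFoldr (the library's default definition is written slightly differently, but should compute the same thing):

```haskell
import Data.Monoid (Product (..), Sum (..))

-- foldMap expressed through foldr: combine with (<>) and start from mempty,
-- i.e. the monoid's identity plays the role of the initial accumulator.
foldMapViaFoldr :: (Foldable t, Monoid m) => (a -> m) -> t a -> m
foldMapViaFoldr f = foldr (\x acc -> f x <> acc) mempty

-- The three functions from the question, rebuilt on top of it.
sumVia :: Num a => [a] -> a
sumVia = getSum . foldMapViaFoldr Sum

productVia :: Num a => [a] -> a
productVia = getProduct . foldMapViaFoldr Product

concatVia :: [[a]] -> [a]
concatVia = foldMapViaFoldr id
```

Here mempty is 0 for Sum, 1 for Product and [] for lists, which is exactly the "neutral element" pattern observed in the question.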
You can find the full list of (built-in) Monoid instances on Hackage, but a few others that you might find of interest:
|| on Booleans is a monoid with the Any newtype.
&& on Booleans is a monoid with the All newtype.
Function composition is a monoid with the Endo newtype (short for "endomorphism")
As an exercise, you might consider trying to pinpoint the identity of each of these operations.

How to break out from a fold function in haskell when the accumulator met a certain condition?

I'm calculating the sum of a list after applying someFunction to every element of it like so:
sum (map someFunction myList)
someFunction is very resource heavy so to optimise it I want to stop calculating the sum if it goes above a certain threshold.
It seems like I need to use fold but I don't know how to break out of it if the accumulator reaches the threshold. My guess is to somehow compose fold and takeWhile but I'm not exactly sure how.
Another technique is to use a foldM with Either to capture the early termination effect. Left signals early termination.
import Control.Monad(foldM)
sumSome :: (Num n,Ord n) => n -> [n] -> Either n n
sumSome thresh = foldM f 0
  where
    f a n
      | a >= thresh = Left a
      | otherwise = Right (a+n)
To ignore the exit status, just compose with either id id.
sumSome' :: (Num n,Ord n) => n -> [n] -> n
sumSome' n = either id id . sumSome n
One of the options would be using scanl function, which returns a list of intermediate calculations of foldl.
Thus, scanl1 (+) (map someFunction myList) will return the intermediate sums of your calculations. And since Haskell is a lazy language it won't calculate all the values of myList until you need it. For example:
take 5 $ scanl1 (+) (map someFunction myList)
will apply someFunction to only the first 5 elements of myList and return the list of the first 5 partial sums.
After that you can use either takeWhile or dropWhile and stop the calculation, when a certain condition is True. For example:
head $ dropWhile (< 1000) $ scanl1 (+) [1..1000000000]
will stop the calculation, when sum of the numbers reaches 1000 and returns 1035.
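The scanl1 idea can be wrapped in a small total function, sketched here. The name firstSumAtLeast is hypothetical; listToMaybe is used in place of head so that an input whose sums never reach the threshold gives Nothing instead of an error.

```haskell
import Data.Maybe (listToMaybe)

-- First partial sum that is >= threshold, or Nothing if no partial sum
-- of a finite list gets that far. Elements past that point are not evaluated.
firstSumAtLeast :: (Num a, Ord a) => a -> [a] -> Maybe a
firstSumAtLeast threshold =
  listToMaybe . dropWhile (< threshold) . scanl1 (+)
```

For example, firstSumAtLeast 1000 [1 ..] gives Just 1035, matching the result quoted above.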
This will do what you ask about without building the intermediate list as scanl' would (and scanl would even cause a thunks build-up on top of that):
foldl'Breaking break reduced reducer acc list =
  foldr cons (\acc -> acc) list acc
  where
    cons x r acc | break acc x = reduced acc x
                 | otherwise = r $! reducer acc x
cf. related wiki page.
Use a bounded addition operator instead of (+) with foldl.
foldl (\b a -> b + if b > someThreshold then 0 else a) 0 (map someFunction myList)
Because Haskell is non-strict, only the calls to someFunction that are needed to evaluate the if-then-else are themselves evaluated. Note that foldl still traverses the entire list.
> foldl (\b a -> b + if b > 10 then 0 else a) 0 (map (trace "foo") [1..20])
foo
foo
foo
foo
foo
15
sum [1..5] > 10, and you can see that trace "foo" only executes 5 times, not 20.
Instead of foldl, though, you should use the strict version foldl' from Data.Foldable.
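A minimal sketch of the foldl' variant, assuming the hypothetical name boundedSumStrict. The accumulator is forced at every step, but an element is still only evaluated while the running sum has not passed the threshold.

```haskell
import Data.List (foldl')

-- Strict left fold with a "bounded" addition: once the running sum exceeds
-- the threshold, later elements are replaced by 0 and never evaluated.
-- The whole list spine is still traversed, so the input must be finite.
boundedSumStrict :: (Num a, Ord a) => a -> [a] -> a
boundedSumStrict threshold = foldl' step 0
  where
    step b a = b + (if b > threshold then 0 else a)
```

For instance, boundedSumStrict 10 [1 .. 20] gives 15, the same result as the GHCi session above.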
You could try making your own sum function, maybe call it boundedSum that takes
an Integer upper bound
an [Integer] to sum over
a "sum up until this point" value to be compared with the upper bound
and returns the sum of the list.
boundedSum :: Integer -> [Integer] -> Integer -> Integer
boundedSum upperBound (x : xs) prevSum =
  let currentSum = prevSum + x
  in
    if currentSum > upperBound
      then upperBound
      else boundedSum upperBound xs currentSum
boundedSum upperBound [] prevSum =
  prevSum
I think this way you won't "eat up" more of the list if the sum up until the current element exceeds upperBound.
EDIT: The answers to this question suggest better techniques than mine and the question itself looks rather similar to yours.
This is a possible solution:
last . takeWhile (<=100) . scanl (+) 0 . map (^2) $ [1..]
Dissected:
take your starting list ([1..] in the example)
map your expensive function ((^2))
compute partial sums scanl (+) 0
stop after the partial sums become too large (keep those (<=100))
take the last one
If performance matters, also try scanl', which might improve it.
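The scanl' variant could be sketched like this. The name sumWhileAtMost is hypothetical, and the sketch assumes a non-negative limit so that the initial partial sum 0 always survives takeWhile and last is safe.

```haskell
import Data.List (scanl')

-- Largest partial sum (starting from 0) that stays <= limit.
-- Assumes limit >= 0; otherwise takeWhile yields [] and last would fail.
-- scanl' forces each partial sum as it is produced, avoiding thunk build-up.
sumWhileAtMost :: (Num a, Ord a) => a -> [a] -> a
sumWhileAtMost limit = last . takeWhile (<= limit) . scanl' (+) 0
```

With the example from this answer, sumWhileAtMost 100 (map (^2) [1 ..]) gives 91, since 1 + 4 + 9 + 16 + 25 + 36 = 91 and adding 49 would exceed 100.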
Something like this using until :: (a -> Bool) -> (a -> a) -> a -> a from the Prelude
sumUntil :: Real a => a -> [a] -> a
sumUntil threshold u = result
  where
    (_, result) = until stopCondition next (u, 0)
    next :: Real a => ([a], a) -> ([a], a)
    next ((x:xs), y) = (xs, x + y)
    -- no local signature here: stopCondition uses threshold, so its type is tied to sumUntil's a
    stopCondition (ls, x) = null ls || x > threshold
Then apply
sumUntil 10 (map someFunction myList)
This post is already a bit older but I'd like to mention a way to generalize the nice code of @trevor-cook above to break fold with the additional possibility to return not only a default value or the accumulator but also the index and element of the list where the breaking condition was satisfied:
import Control.Monad (foldM)
breakFold step initialValue list exitCondition exitFunction =
  either id (exitFunction (length list) (last list))
    (foldM f initialValue (zip [0..] list))
  where f acc (index,x)
          | exitCondition index x acc
            = Left (exitFunction index x acc)
          | otherwise = Right (step index x acc)
It also only requires to import foldM. Examples for the usage are:
mysum thresh list = breakFold (\i x acc -> x + acc) 0 list
                      (\i x acc -> x + acc > thresh)
                      (\i x acc -> acc)
myprod thresh list = breakFold (\i x acc -> x * acc) 1 list
                       (\i x acc -> acc == thresh)
                       (\i x acc -> (i,x,acc))
returning
*myFile> mysum 42 [1,1..]
42
*myFile> myprod 0 ([1..5]++[0,0..])
(6,0,0)
*myFile> myprod 0 (map (\n->1/n) [1..])
(178,5.58659217877095e-3,0.0)
In this way, one can use the index and the last evaluated list value as input for further functions.
Despite the age of this post, I'll add a possible solution. I like continuations because I find them very useful in terms of flow control.
import Data.Function (fix)
breakableFoldl
  :: (b -> a -> (b -> r) -> (b -> r) -> r)
  -> b
  -> [a]
  -> (b -> r)
  -> r
breakableFoldl f b (x : xs) = \ exit ->
  f b x exit $ \ acc ->
    breakableFoldl f acc xs exit
breakableFoldl _ b _ = ($ b)
breakableFoldr
  :: (a -> b -> (b -> r) -> (b -> r) -> r)
  -> b
  -> [a]
  -> (b -> r)
  -> r
breakableFoldr f b l = \ exit ->
  fix (\ fold acc xs next ->
    case xs of
      x : xs' -> fold acc xs' (\ acc' -> f x acc' exit next)
      _ -> next acc) b l exit
exampleL = breakableFoldl (\ acc x exit next ->
    ( if acc > 15
      then exit
      else next . (x +)
    ) acc
  ) 0 [1..9] print
exampleR = breakableFoldr (\ x acc exit next ->
    ( if acc > 15
      then exit
      else next . (x +)
    ) acc
  ) 0 [1..9] print

Lagrange Interpolation for a schema based on Shamir's Secret Sharing

I'm trying to debug an issue with an implementation of a threshold encryption scheme. I've posted this question on crypto to get some help with the actual scheme but was hoping to get a sanity check on the simplified code I am using.
Essentially the crypto system uses Shamir's Secret Sharing to combine the shares of a key. The polynomial is each member of the list 'a' multiplied by an increasing power of the parameter of the polynomial. I've left out the mod by prime to simplify the code as the actual implementation uses PBC via a Haskell wrapper.
I have for the polynomial
poly :: [Integer] -> Integer -> Integer
poly as xi = (f 1 as)
  where
    f _ [] = 0
    f 0 _ = 0
    f s (a:as) = (a * s) + f (s * xi) as
The Lagrange interpolation is:
interp0 :: [(Integer, Integer)] -> Integer
interp0 xys = round (sum $ zipWith (*) ys $ fmap (f xs) xs)
  where
    xs = map (fromIntegral . fst) xys
    ys = map (fromIntegral . snd) xys
    f :: (Eq a, Fractional a) => [a] -> a -> a
    f xs xj = product $ map (p xj) xs
    p :: (Eq a, Fractional a) => a -> a -> a
    p xj xm = if xj == xm then 1 else negate (xm / (xj - xm))
and the split and combination code is
execPoly as@(a0:_) = do
  let xs = zipWith (,) [0..] (fmap (poly as) [0..100])
  let t = length as + 1
  let offset = 1
  let shares = take t (drop offset xs)
  let sm2 = interp0 shares
  putText ("poly and interp over " <> show as <> " = " <> show sm2 <> ". Should be " <> show a0)
main :: IO ()
main = do
execPoly [10,20,30,40,50,60,70,80,90,100,110,120,130,140,150] --1
execPoly [10,20,30,40,50,60,70,80] -- 2
execPoly(1) fails to combine to 10 but execPoly(2) combines correctly. The magic threshold seems to be 8.
Is my code correct? Am I missing something in the implementation that limits the threshold size to 8?
As MathematicalOrchid said it was a precision problem.
Updated the code to:
f :: (Eq a, Integral a) => [a] -> a -> Ratio a
f xs xj = product $ map (p xj) xs
p :: (Eq a, Integral a)=> a -> a -> Ratio a
p xj xm = if xj == xm then (1 % 1) else (negate xm) % (xj - xm)
And it works as expected.
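For reference, a self-contained sketch of the exact-arithmetic version is given here. It is not the original code: the names polyHorner, interpAtZero and recoverSecret are hypothetical, the polynomial is evaluated with Horner's rule instead of the recursion in the question, and the modular reduction is left out as in the question.

```haskell
import Data.Ratio (numerator, (%))

-- Evaluate a0 + a1*x + a2*x^2 + ... with Horner's rule.
polyHorner :: [Integer] -> Integer -> Integer
polyHorner coeffs x = foldr (\a acc -> a + x * acc) 0 coeffs

-- Lagrange interpolation evaluated at 0, using exact rationals.
-- Each basis term is the product over m /= j of (-x_m) / (x_j - x_m).
-- Assumes the x coordinates are pairwise distinct.
interpAtZero :: [(Integer, Integer)] -> Rational
interpAtZero xys = sum [ fromInteger y * basis xj | (xj, y) <- xys ]
  where
    xs = map fst xys
    basis xj = product [ negate xm % (xj - xm) | xm <- xs, xm /= xj ]

-- Split into one share per coefficient (at x = 1, 2, ...) and recombine.
-- The result is an integer here, so taking the numerator loses nothing.
recoverSecret :: [Integer] -> Integer
recoverSecret coeffs = numerator (interpAtZero shares)
  where
    shares = [ (x, polyHorner coeffs x) | x <- take (length coeffs) [1 ..] ]
```

Because Rational never rounds, recoverSecret [10,20..150] should give back 10 even though the intermediate terms are far too large for Double to represent exactly.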

Transforming a function that computes a fixed point

I have a function which computes a fixed point in terms of iterate:
equivalenceClosure :: (Ord a) => Relation a -> Relation a
equivalenceClosure = fst . List.head                    -- "guaranteed" to exist
                   . List.dropWhile (uncurry (/=))      -- removes pairs that are not equal
                   . U.List.pairwise (,)                -- applies (,) to adjacent list elements
                   . iterate ( reflexivity
                             . symmetry
                             . transitivity
                             )
Notice that we can abstract from this to:
findFixedPoint :: (a -> a) -> a -> a
findFixedPoint f = fst . List.head
                 . List.dropWhile (uncurry (/=)) -- dropWhile we have not reached the fixed point
                 . U.List.pairwise (,)           -- applies (,) to adjacent list elements
                 . iterate f
Can this function be written in terms of fix? It seems like there should be a transformation from this scheme to something with fix in it, but I don't see it.
There's quite a bit going on here, from the mechanics of lazy evaluation, to the definition of a fixed point to the method of finding a fixed point. In short, I believe you may be incorrectly interchanging the fixed point of function application in the lambda calculus with your needs.
It may be helpful to note that your implementation of finding the fixed-point (utilizing iterate) requires a starting value for the sequence of function application. Contrast this to the fix function, which requires no such starting value (As a heads up, the types give this away already: findFixedPoint is of type (a -> a) -> a -> a, whereas fix has type (a -> a) -> a). This is inherently because the two functions do subtly different things.
Let's dig into this a little deeper. First, I should say that you may need to give a little more information (your implementation of pairwise, for example). With a naive first try, and my (possibly flawed) implementation of what I believe you want out of pairwise, your findFixedPoint function gives the same result as fix, but only for a certain class of functions.
Let's take a look at some code:
{-# LANGUAGE RankNTypes #-}
import Control.Monad.Fix
import qualified Data.List as List
findFixedPoint :: forall a. Eq a => (a -> a) -> a -> a
findFixedPoint f = fst . List.head
                 . List.dropWhile (uncurry (/=)) -- dropWhile we have not reached the fixed point
                 . pairwise (,)                  -- applies (,) to adjacent list elements
                 . iterate f
pairwise :: (a -> a -> b) -> [a] -> [b]
pairwise f [] = []
pairwise f (x:[]) = []
pairwise f (x:(xs:xss)) = f x xs:pairwise f xss
contrast this to the definition of fix:
fix :: (a -> a) -> a
fix f = let x = f x in x
and you'll notice that we're finding a very different kind of fixed-point (i.e. we abuse lazy evaluation to generate a fixed point for function application in the mathematical sense, where we only stop evaluation iff* the resulting function, applied to itself, evaluates to the same function).
For illustration, let's define a few functions:
lambdaA = const 3
lambdaB = (*)3
and let's see the difference between fix and findFixedPoint:
*Main> fix lambdaA -- evaluates to const 3 (const 3) = const 3
-- fixed point after one iteration
3
*Main> findFixedPoint lambdaA 0 -- evaluates to [const 3 0, const 3 (const 3 0), ... thunks]
-- followed by grabbing the head.
3
*Main> fix lambdaB -- does not stop evaluating
^CInterrupted.
*Main> findFixedPoint lambdaB 0 -- evaluates to [0, 0, ...thunks]
-- followed by grabbing the head
0
Now, if we can't specify the starting value, what is fix used for? It turns out that by adding fix to the lambda calculus, we gain the ability to specify the evaluation of recursive functions. Consider fact' = \rec n -> if n == 0 then 1 else n * rec (n-1). We can compute the fixed point of fact' as:
*Main> (fix fact') 5
120
Here, evaluating (fix fact') repeatedly applies fact' to itself until we reach the same function, which we then call with the value 5. We can see this in:
fix fact'
= fact' (fix fact')
= (\rec n -> if n == 0 then 1 else n * rec (n-1)) (fix fact')
= \n -> if n == 0 then 1 else n * fix fact' (n-1)
= \n -> if n == 0 then 1 else n * fact' (fix fact') (n-1)
= \n -> if n == 0 then 1
        else n * (\rec n' -> if n' == 0 then 1 else n' * rec (n'-1)) (fix fact') (n-1)
= \n -> if n == 0 then 1
        else n * (if n-1 == 0 then 1 else (n-1) * fix fact' (n-2))
= \n -> if n == 0 then 1
        else n * (if n-1 == 0 then 1
                  else (n-1) * (if n-2 == 0 then 1
                                else (n-2) * fix fact' (n-3)))
= ...
So what does all this mean? Depending on the function you're dealing with, you won't necessarily be able to use fix to compute the kind of fixed point you want. This is, to my knowledge, dependent on the function(s) in question. Not all functions have the kind of fixed point computed by fix!
*I've avoided talking about domain theory, as I believe it would only confuse an already subtle topic. If you're curious, fix finds a certain kind of fixed point, namely the least available fixed point of the poset the function is specified over.
Just for the record, it is possible to define the function findFixedPoint using fix.
As Raeez has pointed out, recursive functions can be defined in terms of fix.
The function that you are interested in can be recursively defined as:
findFixedPoint :: Eq a => (a -> a) -> a -> a
findFixedPoint f x =
  case (f x) == x of
    True -> x
    False -> findFixedPoint f (f x)
This means that we can define it as fix ffp where ffp is:
ffp :: Eq a => ((a -> a) -> a -> a) -> (a -> a) -> a -> a
ffp g f x =
  case (f x) == x of
    True -> x
    False -> g f (f x)
For a concrete example, let us assume that f is defined as
f = drop 1
It is easy to see that for every finite list l we have findFixedPoint f l == [].
Here is how fix ffp would work when the "value argument" is []:
(fix ffp) f []
= { definition of fix }
ffp (fix ffp) f []
= { f [] = [] and definition of ffp }
[]
On the other hand, if the "value argument" is [42], we would have:
fix ffp f [42]
= { definition of fix }
ffp (fix ffp) f [42]
= { f [42] =/= [42] and definition of ffp }
(fix ffp) f (f [42])
= { f [42] = [] }
(fix ffp) f []
= { see above }
[]
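Put together as a runnable sketch, using fix from Data.Function. The name findFixedPointViaFix is hypothetical, and the local let that shares f x between the comparison and the recursive call is a small addition not present in the definitions above.

```haskell
import Data.Function (fix)

-- The non-recursive "step": g stands for the recursive call that fix ties back in.
ffp :: Eq a => ((a -> a) -> a -> a) -> (a -> a) -> a -> a
ffp g f x =
  let y = f x                -- compute f x once
  in if y == x then x else g f y

-- Iterates f from a starting value until the value stops changing.
-- Does not terminate if the iteration never stabilizes.
findFixedPointViaFix :: Eq a => (a -> a) -> a -> a
findFixedPointViaFix = fix ffp
```

For example, findFixedPointViaFix (drop 1) [42] evaluates to [], following the same steps as the derivation above.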
