Haskell infinite list of Bernoulli distributed booleans - haskell

I need a list of biased random booleans. Each boolean needs to have the same probability of being True (Bernoulli distributed). These booleans are passed to a function, which generates zero or more output booleans per input boolean. I need an infinite list, because I don't know in advance how many booleans are required to produce enough output. See the (simplified) code below:
import System.Random.MWC
import System.Random.MWC.Distributions
foo :: [Bool] -> [Bool] -- foo outputs zero or more Bools per input Bool
main = do
  gen <- create
  bits <- sequence . repeat $ bernoulli 0.25 gen
  print . take 32 . foo $ bits
Unfortunately, this code just hangs at the second line of main. I guess that there is something non-lazy happening somewhere with Control.Monad.ST?
(I would be able to do something like this with System.Random.randoms, but the resulting values don't have the required distributions.)
Can I fix this while still using the System.Random.MWC library, or do I need to switch to an alternative implementation?

The mwc-random package provides two PrimMonad instances, one for IO and another for ST s. As long as an ST computation is parameterized over all state tags s, we can run the computation and extract the value with runST :: (forall s. ST s a) -> a. By itself this wouldn't be very useful since we'd lose the state: the seed of the random generator, but mwc-random also provides explicit ways to handle the seeds:
save :: PrimMonad m => Gen (PrimState m) -> m Seed
restore :: PrimMonad m => Seed -> m (Gen (PrimState m))
We can use these to make a computation that generates a stream of values from a computation that generates a single value, as long as the generator is in forall s. ST s.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
import System.Random.MWC
import Control.Monad.ST
import System.Random.MWC.Distributions
randomStream :: forall s a. (forall s. GenST s -> ST s a) -> GenST s -> ST s [a]
randomStream item = go
  where
    go :: forall s. GenST s -> ST s [a]
    go gen = do
      x <- item gen
      seed <- save gen
      return (x:runST (restore seed >>= go))
With this we can write your example as
main = do
  bits <- withSystemRandom (randomStream (bernoulli 0.25))
  print . take 32 $ bits
We can actually build generators more sophisticated than using the same generator for each item in the stream. We could thread a state along the stream so that each value can depend on the results of the previous ones.
unfoldStream :: forall s a b. (forall s. b -> GenST s -> ST s (a, b)) -> b -> GenST s -> ST s [a]
unfoldStream item = go
  where
    go :: forall s. b -> GenST s -> ST s [a]
    go b gen = do
      (x,b') <- item b gen
      seed <- save gen
      return (x:runST (restore seed >>= go b'))
The following example stream has results that increase in likelihood every time the result is False.
import Control.Monad.Primitive
interesting :: (PrimMonad m) => Double -> Gen (PrimState m) -> m (Bool, Double)
interesting p gen = do
  result <- bernoulli p gen
  let p' = if result then p else p + (1-p)*0.25
  return (result, p')
main = do
  bits <- withSystemRandom (unfoldStream interesting 0)
  print . take 32 $ bits

The culprit is sequence . repeat - this will hang for (almost?) every monad, since you must perform a potentially infinite number of effects.
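The "(almost?)" matters: a few monads are lazy enough for sequence . repeat to be productive. As a minimal sketch, assuming the lazy State monad from the transformers package and a toy linear congruential generator (lcgStep, bernoulliLcg and lazyBits are hypothetical names for illustration, not part of mwc-random), the same expression yields a usable infinite list:

```haskell
import Control.Monad.Trans.State.Lazy (State, evalState, state)

-- Hypothetical stand-in generator: one step of a toy LCG modulo 2^31.
lcgStep :: Int -> (Int, Int)
lcgStep s =
  let s' = (1103515245 * s + 12345) `mod` 2147483648
  in (s', s')

-- A biased coin in the lazy State monad: True when the scaled draw is below p.
bernoulliLcg :: Double -> State Int Bool
bernoulliLcg p = state $ \s ->
  let (x, s') = lcgStep s
  in (fromIntegral x / 2147483648 < p, s')

-- sequence . repeat is productive here because lazy State only runs
-- as many steps as the consumer of the list demands.
lazyBits :: Double -> Int -> [Bool]
lazyBits p = evalState (sequence (repeat (bernoulliLcg p)))
```

For example, take 32 (lazyBits 0.25 42) returns after 32 generator steps. This is only an illustration of the laziness difference; it says nothing about the statistical quality of the toy generator.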
The simplest solution would be to use a different library, which may not be possible if you are relying on the quality of the numbers produced by mwc-random. The next simplest solution is to rewrite foo to have type [IO Bool] -> IO [Bool] and pass it repeat (bernoulli 0.25 gen). This would allow foo to choose when to stop executing the effects produced by the infinite list. But having your logic inside IO is not very nice.
The standard trick when you need an infinite list of random numbers is to use a pure function f :: StdGen -> (Result, StdGen). Then unfoldr (Just . f) :: StdGen -> [Result], and the output is an infinite list. At first glance, it may appear that mwc-random only has monadic functions, and that there is no pure interface. However, that is not the case, because ST s is an instance of PrimMonad. You also have the functions converting a Gen to a Seed. Using these, you can get a pure RNG function for any monadic one:
{-# LANGUAGE RankNTypes #-}
import System.Random.MWC
import System.Random.MWC.Distributions
import Control.Monad.ST
import Data.List
pureRand :: (forall s . GenST s -> ST s t) -> Seed -> (t, Seed)
pureRand f s = runST $ do
  s' <- restore s
  r <- f s'
  s'' <- save s'
  return (r, s'')
pureBernoulli :: Double -> Seed -> (Bool, Seed)
pureBernoulli a = pureRand (bernoulli a)
foo :: [Bool] -> [Bool]
foo = id
main = do
  gen <- create >>= save
  let bits = unfoldr (Just . pureBernoulli 0.25) gen
  print . take 32 . foo $ bits
It is unfortunate that mwc-random doesn't expose this sort of interface by default but it is pretty easy to get to.
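The unfoldr pattern itself does not depend on mwc-random. A base-only sketch, assuming a toy linear congruential generator in place of Seed (ToySeed, toyBernoulli and toyBits are hypothetical names, not library functions), shows the shape of the step function and how the infinite list falls out of it:

```haskell
import Data.List (unfoldr)

-- Hypothetical stand-in for Seed: the state of a toy LCG.
type ToySeed = Int

toyModulus :: Int
toyModulus = 2147483648 -- 2^31

-- Pure step with the same shape as pureBernoulli:
-- seed in, (value, next seed) out.
toyBernoulli :: Double -> ToySeed -> (Bool, ToySeed)
toyBernoulli p s =
  let s' = (1103515245 * s + 12345) `mod` toyModulus
  in (fromIntegral s' / fromIntegral toyModulus < p, s')

-- The step never returns Nothing, so unfoldr builds an infinite list
-- whose elements are computed only when they are demanded.
toyBits :: Double -> ToySeed -> [Bool]
toyBits p = unfoldr (Just . toyBernoulli p)
```

Because the step is pure, the same seed always reproduces the same stream, which is convenient for testing.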
The other option is slightly more scary - use unsafe functions.
import System.IO.Unsafe
repeatM rand = go where
  go = do
    x <- rand
    xs <- unsafeInterleaveIO go
    return (x : xs)
main2 = do
  gen <- create
  bits <- repeatM (bernoulli 0.25 gen)
  print . take 32 . foo $ bits
Naturally this comes with the usual caveats surrounding unsafe - use it only if you are exceedingly inconvenienced by the pure functions. unsafeInterleaveIO may reorder or never execute effects - if foo, for example, ignores one element, it will never be computed and the corresponding effect of updating the state stored in gen may not happen. For example, the following will print nothing:
snd <$> ((,) <$> unsafeInterleaveIO (putStrLn "Hello") <*> return ())
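To see how demand drives the effects, here is a small base-only sketch that counts how many times the action actually runs. The names lazyRepeatM and effectsRunFor are hypothetical, and an IORef counter stands in for the generator state:

```haskell
import Control.Exception (evaluate)
import Data.IORef (modifyIORef', newIORef, readIORef)
import System.IO.Unsafe (unsafeInterleaveIO)

-- Same shape as repeatM above: run one action now, defer the rest.
lazyRepeatM :: IO a -> IO [a]
lazyRepeatM act = do
  x  <- act
  xs <- unsafeInterleaveIO (lazyRepeatM act)
  return (x : xs)

-- Force the first n cells of the lazy list and report how many times
-- the counting action has been executed by then.
effectsRunFor :: Int -> IO Int
effectsRunFor n = do
  counter <- newIORef (0 :: Int)
  xs <- lazyRepeatM (modifyIORef' counter (+ 1))
  _ <- evaluate (length (take n xs))
  readIORef counter
```

Only the cells that the consumer forces trigger their effect, so the state behind the list advances exactly as far as the list is inspected, and no further.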

Related

How to sample RVarT in IO

I am having difficulties wrapping my brain around RVarT in random-fu. Just as a mental exercise, I am trying to generate Maybe x randomly and combine the results into Maybe (x, x), using monad transformers.
I have managed to pull this off, which seems intuitive to me:
maybeRandSome :: (MaybeT RVar) Int
maybeRandSome = lift $ return 1
maybeRandNone :: (MaybeT RVar) Int
maybeRandNone = MaybeT . return $ Nothing
maybeTwoRands :: (MaybeT RVar) (Int, Int)
maybeTwoRands =
  do
    x <- maybeRandSome
    y <- maybeRandNone
    return (x, y)
And can sample them in IO doing this
> sample $ runMaybeT maybeTwoRands
Nothing
However I cannot figure out if the reverse is possible:
reverseMaybeRandSome :: (RVarT Maybe) Int
reverseMaybeRandSome = lift $ Just 1
reverseMaybeRandNone :: (RVarT Maybe) Int
reverseMaybeRandNone = lift Nothing
reverseMaybeTwoRands :: (RVarT Maybe) (Int, Int)
reverseMaybeTwoRands =
  do
    x <- Random.sample reverseMaybeRandSome
    y <- Random.sample reverseMaybeRandNone
    return (x, y)
Which requires me to lift from Maybe m to MonadRandom m somehow, and I can't figure out if that makes sense or if I am doing something unsound to begin with.
Yes, you're pretty much doing something unsound. MaybeT m a is isomorphic to m (Maybe a) for any monad, including m = RVar, so a MaybeT RVar a is really just an RVar (Maybe a), which is a representation of a random variable taking values in Maybe a. Given this, it's easy enough to imagine sampling two Maybe a-valued random variables and combining them into a Maybe (a,a)-valued random variable in the usual manner (i.e., if either or both are Nothing, the result is Nothing, and if they're Just x and Just y respectively, the result is Just (x,y)). That's what your first chunk of code is doing.
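The combination rule for MaybeT does not depend on the base monad being RVar. A minimal sketch, assuming Identity as a stand-in base monad so it runs without random-fu (pairUp, justOne, noValue and runPair are hypothetical names):

```haskell
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Functor.Identity (Identity (..))

-- Pair two Maybe-valued computations over any base monad m.
-- The result is Nothing as soon as either side is Nothing.
pairUp :: Monad m => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
pairUp mx my = do
  x <- mx
  y <- my
  return (x, y)

-- Counterparts of maybeRandSome and maybeRandNone over Identity.
justOne :: MaybeT Identity Int
justOne = return 1

noValue :: MaybeT Identity Int
noValue = MaybeT (return Nothing)

-- Unwrap to a plain Maybe; this is the m (Maybe a) view with m = Identity.
runPair :: MaybeT Identity a -> Maybe a
runPair = runIdentity . runMaybeT
```

Here runPair (pairUp justOne justOne) is Just (1,1), while runPair (pairUp justOne noValue) is Nothing, mirroring the sample output in the question.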
However, RVarT Maybe a is different. It's an a-valued (not Maybe a-valued) random variable that can use the facilities of the base Maybe monad in generating its values, provided they can be lifted in some sensible way to the final monad in which the "randomness" of the random variable is realized.
To understand what this means, we have to take a more detailed look at the types RVar and RVarT.
The type RVar a represents an a-valued random variable. In order to actually turn this representation into a real random value, you have to run it with:
runRVar :: RandomSource m s => RVar a -> s -> m a
This is a little too general, so imagine it being specialized to:
runRVar :: RVar a -> StdRandom -> IO a
Note that StdRandom is the only valid value of StdRandom here, so we'll always write runRVar something StdRandom, which can also be written sample something.
With this specialization, you should view an RVar a as a monadic recipe for constructing a random variable using a limited set of randomization primitives that runRVar converts into IO actions that realize the randomization primitives with respect to a global random number generator. This conversion to IO actions is what allows the recipe to generate an actual sampled random value. If you're interested, you can find the limited set of randomization primitives in Data.Random.Internal.Source.
Similarly, RVarT n a is also an a-valued random variable (i.e., a recipe for constructing a random variable using a limited set of randomization primitives) that also has access to the "facilities of another base monad n". This recipe can be run inside any final monad that can realize both the randomization primitives and the facilities of the base monad n. In the general case, you run it with:
runRVarTWith :: RandomSource m s =>
    (forall t. n t -> m t) -> RVarT n a -> s -> m a
which takes an explicit lifting function that explains how to lift the facilities of the base monad n to the final monad m.
If the base monad n is Maybe, then its "facilities" are the ability to signal an error or failed computation. You might use those facilities to construct the following somewhat silly random variable:
sqrtNormal :: RVarT Maybe Double
sqrtNormal = do
  z <- stdNormalT
  if z < 0
    then lift Nothing -- use Maybe facilities to signal error
    else return $ sqrt z
Note that, critically, sqrtNormal does not represent a Maybe Double-valued random variable to be generated. Instead, it represents a Double-valued random variable whose generation can fail via the facilities of the base Maybe monad.
In order to realize this random variable (i.e., sample it), we need to run it in an appropriate final monad. The final monad needs to support both the randomization primitives and an appropriately lifted notion of failure from the Maybe monad.
IO works fine, if the appropriate notion of failure is a runtime error:
liftMaybeToIO :: Maybe a -> IO a
liftMaybeToIO Nothing = error "simulation failed!"
liftMaybeToIO (Just x) = return x
after which:
main1 :: IO ()
main1 = print =<< runRVarTWith liftMaybeToIO sqrtNormal StdRandom
generates the square root of a positive standard Gaussian about half the time and throws a runtime error the other half.
If you want to capture failure in a pure form (as a Maybe, for example), then you need to consider realizing the RVar in an appropriate monad. The monad:
MaybeT IO a
will do the trick. It's isomorphic to IO (Maybe a), so it has IO facilities available (needed to realize the randomization primitives) and is capable of signaling failure by returning Nothing. If we write:
main2 :: IO ()
main2 = print =<< runMaybeT act
  where act :: MaybeT IO Double
        act = sampleRVarTWith liftMaybe sqrtNormal
we'll get an error that there's no instance for MonadRandom (MaybeT IO). We can create one as follows:
import Control.Monad.Trans (liftIO)
instance MonadRandom (MaybeT IO) where
  getRandomPrim = liftIO . getRandomPrim
together with an appropriate lifting function:
liftMaybe :: Maybe a -> MaybeT IO a
liftMaybe = MaybeT . return
After which, main2 will return Nothing about half the time and Just the square root of a positive Gaussian the other half.
The full code:
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE FlexibleInstances #-}
import Control.Monad.Trans (liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..))
import Data.Random
import Data.Random.Lift
import Data.Random.Internal.Source
sqrtNormal :: RVarT Maybe Double
sqrtNormal = do
  z <- stdNormalT
  if z < 0
    then lift Nothing -- use Maybe facilities to signal error
    else return $ sqrt z
liftMaybeToIO :: Maybe a -> IO a
liftMaybeToIO Nothing = error "simulation failed!"
liftMaybeToIO (Just x) = return x
main1 :: IO ()
main1 = print =<< runRVarTWith liftMaybeToIO sqrtNormal StdRandom
instance MonadRandom (MaybeT IO) where
  getRandomPrim = liftIO . getRandomPrim
main2 :: IO ()
main2 = print =<< runMaybeT act
  where act :: MaybeT IO Double
        act = runRVarTWith liftMaybe sqrtNormal StdRandom
liftMaybe :: Maybe a -> MaybeT IO a
liftMaybe = MaybeT . return
The way this would all apply to your second example would look something like this, which will always print Nothing:
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE FlexibleInstances #-}
import Control.Monad.Trans (liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..))
import Data.Random
import Data.Random.Lift
import Data.Random.Internal.Source
reverseMaybeRandSome :: RVarT Maybe Int
reverseMaybeRandSome = return 1
reverseMaybeRandNone :: RVarT Maybe Int
reverseMaybeRandNone = lift Nothing
reverseMaybeTwoRands :: RVarT Maybe (Int, Int)
reverseMaybeTwoRands =
  do
    x <- reverseMaybeRandSome
    y <- reverseMaybeRandNone
    return (x, y)
instance MonadRandom (MaybeT IO) where
  getRandomPrim = liftIO . getRandomPrim
runRVarTMaybe :: RVarT Maybe a -> IO (Maybe a)
runRVarTMaybe act = runMaybeT $ runRVarTWith liftMaybe act StdRandom
  where
    liftMaybe :: Maybe a -> MaybeT IO a
    liftMaybe = MaybeT . return
main :: IO ()
main = print =<< runRVarTMaybe reverseMaybeTwoRands

Unexpected memory growth with Control.Monad foldM

I have the following code, which has been stripped down to what I think is as minimal as possible while still showing some very odd behaviour.
The code consists of two source files:
One to define some data:
module MyFunction where
data MyFunction =
  MyFunction {
    functionNumber :: Int,
    functionResult :: IO String
  }
makeMyFunction :: Show a => Int -> IO a -> MyFunction
makeMyFunction number result = MyFunction {
  functionNumber = number,
  functionResult = result >>= return . show }
And the other is Main:
module Main (main) where
import System.CPUTime (getCPUTime)
import Data.List (foldl')
import Data.Foldable (foldlM)
import Control.Monad (foldM)
import MyFunction
exampleFunction = do
  --let x = foldl' (\a b -> a `seq` (a + b)) 0 [1..20000000] -- This works
  --x <- foldlM (\a b -> a `seq` return (a + b)) 0 [1..20000000] -- This works (*)
  x <- foldM (\a b -> a `seq` return (a + b)) 0 [1..20000000] -- This doesn't
  print x
  return ()
runFunction fn = do
  result <- functionResult fn
  duration <- getCPUTime
  if result /= "()"
    then putStrLn ""
    else return ()
  putStrLn (show (fromIntegral duration / (10^9)) ++ "ms")
  return fn
main = do
  runFunction (makeMyFunction 123 exampleFunction)
  return ()
The code as above (compiled using GHC 7.10.3 with stack 1.0.0 with default flags) has a rapid increase in memory usage (exceeding 1GB), and takes typically 3.3 seconds.
If I make changes to the code, for example:
Use one of the commented alternatives to the problem line
Take out any line from runFunction
then the memory usage remains minimal and the program takes only about 1 second.
What surprises me most is that replacing foldM with foldlM fixes the problem, even though, as far as I know, foldM = foldlM.
Changes to code that I can't see having any relationship to the problem line also fix the problem, for example removing the last putStrLn.
Another oddity is that merging the MyFunction module into the Main module doesn't fix the problem; instead it causes foldlM to behave like foldM and use excessive memory.
In the real code that this came from, I have a large number of exampleFunctions and significantly more Main code, and every so often I encounter this sort of unexplained memory usage from functions, which can usually be resolved by some sort of voodoo.
I'm looking for an explanation for the behaviour. If I know why this occurs I can then look into avoiding it. Could this be a compiler issue, or maybe just a misunderstanding on my part?
(*) I've highlighted the secondary issue that causes the same memory growth to occur with foldlM.
Here is foldlM from Foldable.hs (ghc)
-- | Monadic fold over the elements of a structure,
-- associating to the left, i.e. from left to right.
foldlM :: (Foldable t, Monad m) => (b -> a -> m b) -> b -> t a -> m b
foldlM f z0 xs = foldr f' return xs z0
  where f' x k z = f z x >>= k
and foldM from Monad.hs
foldM :: (Foldable t, Monad m) => (b -> a -> m b) -> b -> t a -> m b
{-# INLINEABLE foldM #-}
{-# SPECIALISE foldM :: (a -> b -> IO a) -> a -> [b] -> IO a #-}
{-# SPECIALISE foldM :: (a -> b -> Maybe a) -> a -> [b] -> Maybe a #-}
foldM = foldlM
I placed these definitions in a separate module Test and tested the execution with and without the INLINEABLE / SPECIALISE lines. Whatever the reason is, leaving out the SPECIALISE directives helped, and the execution time and memory usage matched those of foldlM.
After a little more digging, removing the line
{-# SPECIALISE foldM :: (a -> b -> IO a) -> a -> [b] -> IO a #-}
had the largest effect.
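If the goal is simply to stop depending on how the library's foldM happens to be specialised, one option is a hand-written loop with a strict accumulator. This is only a sketch under that assumption (foldMStrict is a hypothetical name, not a base function), and it does not by itself explain the behaviour observed above:

```haskell
{-# LANGUAGE BangPatterns #-}

-- A list-only monadic left fold with a strict accumulator.
-- The bang patterns force the accumulator to weak head normal form
-- before each step, so no chain of (+) thunks can build up in it.
foldMStrict :: Monad m => (b -> a -> m b) -> b -> [a] -> m b
foldMStrict f = go
  where
    go !acc []       = return acc
    go !acc (x : xs) = f acc x >>= \acc' -> go acc' xs
```

In the example from the question it would be used as x <- foldMStrict (\a b -> return (a + b)) 0 [1..20000000].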

Is there any safe way to generate a lazy list in IO?

I'd like to have a lazily-generated list of random numbers, and I managed to do it but with unsafeInterleaveIO:
rs :: Random a => (a,a) -> IO [a]
rs b = do
  r <- randomRIO b
  ns <- unsafeInterleaveIO $ rs b
  return (r:ns)
Is there any safe way to produce this kind of lazily generated list?
If you want "lazily generated elements with effects", one solution is to eschew the conventional list type and use a List monad transformer, like ListT from the pipes library:
import System.Random
import Control.Monad
import Pipes
import qualified Pipes.Prelude as P
rs :: (Random a, MonadPlus m, MonadIO m) => (a,a) -> m a
rs b = liftIO (randomRIO b) `mplus` rs b
main :: IO ()
main = runEffect $ enumerate (rs (1::Int,10)) >-> P.take 5 >-> P.print
The result is:
*Main> :main
7
2
5
6
4
However, this bars you from using the conventional list functions to consume the "effectful list"; you are thrust into the pipes ecosystem.
(Applicative folds from the foldl package can also be used to consume the list, with the impurely and foldM auxiliary functions.)
The MonadPlus interface should be used as much as possible while defining effectful lists, as described here. It makes the effectful lists more library-agnostic.
A better way would probably be to generate a seed and then calculate the list using randoms:
randomRsIO :: Random a => (a, a) -> IO [a]
randomRsIO b = do
  g <- newStdGen
  return $ randomRs b g
Or simply
randomRsIO b = fmap (randomRs b) newStdGen

Generate list of random values and also get a new generator

I'm using System.Random and the Random typeclass in my application to generate random numbers. However I'd like to generate a list of random Floats of arbitrary length with a function like randoms :: StdGen -> Int -> ([Float], StdGen)
Without the constraint of getting a new generator, I could easily write
randoms gen n = (take n $ randoms gen) :: [Float]
However this leaves me with the same random generator I started with, which means if I were to run this function twice in a row I'd get the same list unless I went and used the generator elsewhere to get a new one.
How can I generate an infinite (or arbitrary length) list of random values while also "refreshing" my random generator?
Well, let's look at the function you do have:
random :: StdGen -> (Float, StdGen) -- From System.Random
We can wrap this in the State monad to get a stateful computation:
state :: (s -> (a, s)) -> State s a -- From Control.Monad.Trans.State
random' :: State StdGen Float
random' = state random
Now, we can generate a bunch of floats just using replicateM:
replicateM :: (Monad m) => Int -> m a -> m [a] -- From Control.Monad
randoms' :: Int -> State StdGen [Float]
randoms' n = replicateM n random'
Finally, we unwrap the State to get back the explicit generator passing:
randoms :: Int -> StdGen -> ([Float], StdGen)
randoms n = runState (randoms' n)
If you combine all of these into one function definition you get:
randoms :: Int -> StdGen -> ([Float], StdGen)
randoms n = runState (replicateM n (state random))
In other words, we can describe the process as:
wrap random in the State monad
replicate it n times
unwrap it
This is why monads are such an important concept. Things that can seem tricky at first tend to be simple computations when viewed through the lens of the monad interface.
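The three steps can be sketched without the random package, assuming a toy linear congruential generator in place of StdGen (ToyGen, nextFloat, toyRandoms and twoBatches are hypothetical names for illustration):

```haskell
import Control.Monad (replicateM)
import Control.Monad.Trans.State (runState, state)

-- Hypothetical stand-in for StdGen: the state of a toy LCG.
type ToyGen = Int

-- Same shape as random: generator in, (value, next generator) out.
nextFloat :: ToyGen -> (Float, ToyGen)
nextFloat g =
  let g' = (1103515245 * g + 12345) `mod` 2147483648
  in (fromIntegral g' / 2147483648, g')

-- Wrap in State, replicate n times, unwrap again.
toyRandoms :: Int -> ToyGen -> ([Float], ToyGen)
toyRandoms n = runState (replicateM n (state nextFloat))

-- Thread the returned generator into the next call so the
-- second batch continues where the first one stopped.
twoBatches :: Int -> ToyGen -> ([Float], [Float])
twoBatches n g0 =
  let (xs, g1) = toyRandoms n g0
      (ys, _)  = toyRandoms n g1
  in (xs, ys)
```

The generator returned by toyRandoms is what makes the second batch differ from the first; calling toyRandoms twice with the same generator repeats the same list.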
Gabriel's answer is correct and this is pretty much how the MonadRandom package is implemented (A state Monad parameterised with a random generator).
It saves you defining it every time, and it comes with a Monad transformer too, so you can transform any other Monad into one that can also produce random values.
Your example could be easily implemented as:
(runRand $ take n `fmap` getRandoms) :: RandomGen g => g -> ([Int], g)
StdGen happens to be an instance of RandomGen, so you can simply plug it in and go!
An alternative without State or split, using mapAccumL from Data.List (and swap from Data.Tuple):
nRandoms n gen = mapAccumL(\g _ -> swap $ random g) gen [1..n]
though I have to say I don't have a convincing argument for why this should be better in any way.
You can define a function whose type matches the one you say you’d like to have, albeit more generally.
import System.Random
randoms' :: (RandomGen g, Random a) => g -> Int -> ([a], g)
randoms' g n =
  let (g1, g2) = split g
  in (take n $ randoms g1, g2)
Even though it uses split
split :: g -> (g, g)
The split operation allows one to obtain two distinct random number generators. This is very useful in functional programs (for example, when passing a random number generator down to recursive calls), but very little work has been done on statistically robust implementations of split …
it still doesn’t do what you want. (I use Bool in the examples below for easier visual comparison.)
ghci> g <- getStdGen
ghci> randoms' g 5 :: ([Bool], StdGen)
([False,False,False,True,False],1648254783 2147483398)
ghci> randoms' g 5 :: ([Bool], StdGen)
([False,False,False,True,False],1648254783 2147483398)
Note that the random arrays are the same.
Although the function goes to the trouble of splitting the generator, we promptly discard it. Instead, make use of g2 by threading it to the subsequent call as in
ghci> let (a1,g2) = randoms' g 5 :: ([Bool], StdGen)
ghci> let (a2,_) = randoms' g2 5 :: ([Bool], StdGen)
ghci> (a1,a2)
([False,False,False,True,False],[True,True,True,False,True])
If your code is running in the IO monad, you can use setStdGen to replace the global random number generator at the end, as in
myAction :: Int -> IO ([Float],[Float])
myAction n = do
  g <- getStdGen
  let (f1,g2) = randoms' g n
  let (f2,g3) = randoms' g2 n
  setStdGen g3
  return (f1, f2)
Threading state around is awkward and error-prone. Consider using State or ST if you have lots of repeated boilerplate.

Constructing efficient monad instances on `Set` (and other containers with constraints) using the continuation monad

Set, similarly to [], has perfectly well-defined monadic operations. The problem is that they require the values to satisfy an Ord constraint, so it's impossible to define return and >>= without any constraints. The same problem applies to many other data structures that require some kind of constraint on their values.
The standard trick (suggested to me in a haskell-cafe post) is to wrap Set into the continuation monad. ContT doesn't care if the underlying type functor has any constraints. The constraints become only needed when wrapping/unwrapping Sets into/from continuations:
import Control.Monad.Cont
import Data.Foldable (foldrM)
import Data.Set
setReturn :: a -> Set a
setReturn = singleton
setBind :: (Ord b) => Set a -> (a -> Set b) -> Set b
setBind set f = foldl' (\s -> union s . f) empty set
type SetM r a = ContT r Set a
fromSet :: (Ord r) => Set a -> SetM r a
fromSet = ContT . setBind
toSet :: SetM r r -> Set r
toSet c = runContT c setReturn
This works as needed. For example, we can simulate a non-deterministic function that either increases its argument by 1 or leaves it intact:
step :: (Ord r) => Int -> SetM r Int
step i = fromSet $ fromList [i, i + 1]
-- repeated application of step:
stepN :: Int -> Int -> Set Int
stepN times start = toSet $ foldrM ($) start (replicate times step)
Indeed, stepN 5 0 yields fromList [0,1,2,3,4,5]. If we used [] monad instead, we would get
[0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5]
instead.
The problem is efficiency. If we call stepN 20 0 the output takes a few seconds and stepN 30 0 doesn't finish within a reasonable amount of time. It turns out that all Set.union operations are performed at the end, instead of performing them after each monadic computation. The result is that exponentially many Sets are constructed and unioned only at the end, which is unacceptable for most tasks.
Is there any way around it, to make this construction efficient? I tried but without success.
(I even suspect that there could be some kinds of theoretical limits following from Curry-Howard isomorphism and Glivenko's theorem. Glivenko's theorem says that for any propositional tautology φ the formula ¬¬φ can be proved in intuitionistic logic. However, I suspect that the length of the proof (in normal form) can be exponentially long. So, perhaps, there could be cases when wrapping a computation into the continuation monad will make it exponentially longer?)
Monads are one particular way of structuring and sequencing computations. The bind of a monad cannot magically restructure your computation so as to happen in a more efficient way. There are two problems with the way you structure your computation.
When evaluating stepN 20 0, the result of step 0 will be computed 20 times. This is because each step of the computation produces 0 as one alternative, which is then fed to the next step, which also produces 0 as alternative, and so on...
Perhaps a bit of memoization here can help.
A much bigger problem is the effect of ContT on the structure of your computation. With a bit of equational reasoning, expanding out the result of replicate 20 step, the definition of foldrM and simplifying as many times as necessary, we can see that stepN 20 0 is equivalent to:
(...(return 0 >>= step) >>= step) >>= step) >>= ...)
All parentheses of this expression associate to the left. That's great, because it means that the RHS of each occurrence of (>>=) is an elementary computation, namely step, rather than a composed one. However, zooming in on the definition of (>>=) for ContT,
m >>= k = ContT $ \c -> runContT m (\a -> runContT (k a) c)
we see that when evaluating a chain of (>>=) associating to the left, each bind will push a new computation onto the current continuation c. To illustrate what is going on, we can use again a bit of equational reasoning, expanding out this definition for (>>=) and the definition for runContT, and simplifying, yielding:
setReturn 0 `setBind`
    (\x1 -> step x1 `setBind`
        (\x2 -> step x2 `setBind` (\x3 -> ...)...)
Now, for each occurrence of setBind, let's ask ourselves what the RHS argument is. For the leftmost occurrence, the RHS argument is the whole rest of the computation after setReturn 0. For the second occurrence, it's everything after step x1, etc. Let's zoom in to the definition of setBind:
setBind set f = foldl' (\s -> union s . f) empty set
Here f represents all the rest of the computation, everything on the right hand side of an occurrence of setBind. That means that at each step, we are capturing the rest of the computation as f, and applying f as many times as there are elements in set. The computations are not elementary as before, but rather composed, and these computations will be duplicated many times.
The crux of the problem is that the ContT monad transformer is transforming the initial structure of the computation, which you meant as a left-associative chain of setBinds, into a computation with a different structure, i.e. a right-associative chain. This is after all perfectly fine, because one of the monad laws says that, for every m, f and g, we have
(m >>= f) >>= g = m >>= (\x -> f x >>= g)
However, the monad laws do not impose that the complexity remain the same on each side of the equations of each law. And indeed, in this case, the left associative way of structuring this computation is a lot more efficient. The left associative chain of setBind's evaluates in no time, because only elementary subcomputations are duplicated.
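For comparison, the left-associative chain can be written directly against Data.Set, without ContT. This is a sketch using the question's setBind and step; stepNLeft is a hypothetical name:

```haskell
import qualified Data.Set as S

-- setBind as in the question: apply f to every element and union the results.
setBind :: Ord b => S.Set a -> (a -> S.Set b) -> S.Set b
setBind set f = S.foldl' (\s -> S.union s . f) S.empty set

-- One non-deterministic step: keep the value or increase it by 1.
step :: Int -> S.Set Int
step i = S.fromList [i, i + 1]

-- Left-associated chain: ((singleton start `setBind` step) `setBind` step) ...
-- Each bind sees an already deduplicated set, so the intermediate sets
-- stay small instead of growing exponentially.
stepNLeft :: Int -> Int -> S.Set Int
stepNLeft times start = foldl setBind (S.singleton start) (replicate times step)
```

With this structure, stepNLeft 30 0 only ever handles sets of at most 31 elements, since duplicates are merged after every step.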
It turns out that other solutions shoehorning Set into a monad also suffer from the same problem. In particular, the set-monad package yields similar runtimes, because it too rewrites left-associative expressions into right-associative ones.
I think you have put your finger on a very important yet rather subtle problem with insisting that Set obeys a Monad interface. And I don't think it can be solved. The problem is that the type of the bind of a monad needs to be
(>>=) :: m a -> (a -> m b) -> m b
i.e. no class constraint is allowed on either a or b. That means that we cannot nest binds on the left without first invoking the monad laws to rewrite into a right-associative chain. Here's why: given (m >>= f) >>= g, the type of the computation (m >>= f) is of the form m b. A value of the computation (m >>= f) is of type b. But because we can't hang any class constraint onto the type variable b, we can't know that the value we got satisfies an Ord constraint, and therefore cannot use this value as the element of a set on which we want to be able to compute unions.
Recently on Haskell Cafe Oleg gave an example how to implement the Set monad efficiently. Quoting:
... And yet, the efficient genuine Set monad is possible.
...
Enclosed is the efficient genuine Set monad. I wrote it in direct style (it seems to be faster, anyway). The key is to use the optimized choose function when we can.
{-# LANGUAGE GADTs, TypeSynonymInstances, FlexibleInstances #-}
module SetMonadOpt where
import qualified Data.Set as S
import Control.Monad
data SetMonad a where
  SMOrd :: Ord a => S.Set a -> SetMonad a
  SMAny :: [a] -> SetMonad a
instance Monad SetMonad where
  return x = SMAny [x]
  m >>= f = collect . map f $ toList m
toList :: SetMonad a -> [a]
toList (SMOrd x) = S.toList x
toList (SMAny x) = x
collect :: [SetMonad a] -> SetMonad a
collect [] = SMAny []
collect [x] = x
collect ((SMOrd x):t) = case collect t of
  SMOrd y -> SMOrd (S.union x y)
  SMAny y -> SMOrd (S.union x (S.fromList y))
collect ((SMAny x):t) = case collect t of
  SMOrd y -> SMOrd (S.union y (S.fromList x))
  SMAny y -> SMAny (x ++ y)
runSet :: Ord a => SetMonad a -> S.Set a
runSet (SMOrd x) = x
runSet (SMAny x) = S.fromList x
instance MonadPlus SetMonad where
  mzero = SMAny []
  mplus (SMAny x) (SMAny y) = SMAny (x ++ y)
  mplus (SMAny x) (SMOrd y) = SMOrd (S.union y (S.fromList x))
  mplus (SMOrd x) (SMAny y) = SMOrd (S.union x (S.fromList y))
  mplus (SMOrd x) (SMOrd y) = SMOrd (S.union x y)
choose :: MonadPlus m => [a] -> m a
choose = msum . map return
test1 = runSet (do
  n1 <- choose [1..5]
  n2 <- choose [1..5]
  let n = n1 + n2
  guard $ n < 7
  return n)
-- fromList [2,3,4,5,6]
-- Values to choose from might be higher-order or actions
test1' = runSet (do
  n1 <- choose . map return $ [1..5]
  n2 <- choose . map return $ [1..5]
  n <- liftM2 (+) n1 n2
  guard $ n < 7
  return n)
-- fromList [2,3,4,5,6]
test2 = runSet (do
  i <- choose [1..10]
  j <- choose [1..10]
  k <- choose [1..10]
  guard $ i*i + j*j == k * k
  return (i,j,k))
-- fromList [(3,4,5),(4,3,5),(6,8,10),(8,6,10)]
test3 = runSet (do
  i <- choose [1..10]
  j <- choose [1..10]
  k <- choose [1..10]
  guard $ i*i + j*j == k * k
  return k)
-- fromList [5,10]
-- Test by Petr Pudlak
-- First, general, unoptimal case
step :: (MonadPlus m) => Int -> m Int
step i = choose [i, i + 1]
-- repeated application of step on 0:
stepN :: Int -> S.Set Int
stepN = runSet . f
  where
    f 0 = return 0
    f n = f (n-1) >>= step
-- it works, but clearly exponential
{-
*SetMonad> stepN 14
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
(0.09 secs, 31465384 bytes)
*SetMonad> stepN 15
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
(0.18 secs, 62421208 bytes)
*SetMonad> stepN 16
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
(0.35 secs, 124876704 bytes)
-}
-- And now the optimization
chooseOrd :: Ord a => [a] -> SetMonad a
chooseOrd x = SMOrd (S.fromList x)
stepOpt :: Int -> SetMonad Int
stepOpt i = chooseOrd [i, i + 1]
-- repeated application of step on 0:
stepNOpt :: Int -> S.Set Int
stepNOpt = runSet . f
  where
    f 0 = return 0
    f n = f (n-1) >>= stepOpt
{-
stepNOpt 14
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
(0.00 secs, 515792 bytes)
stepNOpt 15
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
(0.00 secs, 515680 bytes)
stepNOpt 16
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
(0.00 secs, 515656 bytes)
stepNOpt 30
fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
(0.00 secs, 1068856 bytes)
-}
I don't think your performance problems in this case are due to the use of Cont:
step' :: Int -> Set Int
step' i = fromList [i,i + 1]
foldrM' f z0 xs = Prelude.foldl f' setReturn xs z0
  where f' k x z = f x z `setBind` k
stepN' :: Int -> Int -> Set Int
stepN' times start = foldrM' ($) start (replicate times step')
This gets similar performance to the Cont-based implementation but occurs entirely in the Set "restricted monad".
I am not sure if I believe your claim about Glivenko's theorem leading to exponential increase in (normalized) proof size--at least in the Call-By-Need context. That is because we can arbitrarily reuse subproofs (and our logic is second order, we need only a single proof of forall a. ~~(a \/ ~a)). Proofs are not trees, they are graphs (sharing).
In general, you are likely to see performance costs from Cont wrapping Set but they can usually be avoided via
smash :: (Ord r, Ord k) => SetM r r -> SetM k r
smash = fromSet . toSet
I found out another possibility, based on GHC's ConstraintKinds extension. The idea is to redefine Monad so that it includes a parametric constraint on allowed values:
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RebindableSyntax #-}
import qualified Data.Foldable as F
import qualified Data.Set as S
import Prelude hiding (Monad(..), Functor(..))
class CFunctor m where
    -- Each instance defines a constraint its values must satisfy:
    type Constraint m a
    -- The default is no constraints.
    type Constraint m a = ()
    fmap :: (Constraint m a, Constraint m b) => (a -> b) -> (m a -> m b)
class CFunctor m => CMonad (m :: * -> *) where
    return :: (Constraint m a) => a -> m a
    (>>=) :: (Constraint m a, Constraint m b) => m a -> (a -> m b) -> m b
    fail :: String -> m a
    fail = error
-- [] instance
instance CFunctor [] where
    fmap = map
instance CMonad [] where
    return = (: [])
    (>>=) = flip concatMap
-- Set instance
instance CFunctor S.Set where
    -- Sets need Ord.
    type Constraint S.Set a = Ord a
    fmap = S.map
instance CMonad S.Set where
    return = S.singleton
    (>>=) = flip F.foldMap
-- Example:
-- prints fromList [3,4,5]
main = print $ do
    x <- S.fromList [1,2]
    y <- S.fromList [2,3]
    return $ x + y
(The problem with this approach is in the case the monadic values are functions, such as m (a -> b), because they can't satisfy constraints like Ord (a -> b). So one can't use combinators like <*> (or ap) for this constrained Set monad.)
