Stricter Strict State Monad - haskell

The strict state monad is defined using:
m >>= k = State $ \s ->
  case runState m s of
    (a, s') -> runState (k a) s'
But this can still leak memory, because a and s' are left unevaluated. For example, we might have a function f that takes a large object as input and quickly returns (a, s'), but as long as a is left unevaluated the input to f cannot be GC'ed.
One potential solution is to have f return seq a (a, s'), but this isn't always possible if we are using something like MonadRandom, and the state is encapsulated away from f. Is there a version that is defined like this:
m >>= k = State $ \s ->
  case runState m s of
    (!a, !s') -> runState (k a) s'
Does this exist in a library anywhere already?
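For reference, the definition asked for compiles as a sketch like the following. It assumes a hand-rolled State newtype (not the type exported by mtl or transformers) and the BangPatterns extension; the get and put helpers are illustrative additions.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hand-rolled state monad (illustrative; not the mtl/transformers type).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = m >>= \a -> pure (f a)

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  mf <*> ma = mf >>= \f -> ma >>= \a -> pure (f a)

instance Monad (State s) where
  -- Force both the result and the new state to WHNF before continuing.
  m >>= k = State $ \s ->
    case runState m s of
      (!a, !s') -> runState (k a) s'

-- Illustrative helpers, defined here because nothing is imported.
get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)
```

The bangs only force a and s' to weak head normal form, so thunks nested inside a lazy structure can still build up.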

According to the monad identity laws,
return a >>= const b = const b a = b
Thus in particular,
return undefined >>= const b = b
If the >>= operation were strict in the result value, that would break this law, so you shouldn't do that.
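The violation is easy to observe. A sketch, assuming a hand-rolled State type and two illustrative bind functions: bindUsual is the standard strict bind, and bindBangA additionally forces the result value. The hypothetical lawHolds checks left identity with an undefined value and catches the exception if one is thrown.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Exception (SomeException, evaluate, try)

newtype State s a = State { runState :: s -> (a, s) }

returnS :: a -> State s a
returnS a = State $ \s -> (a, s)

-- The usual strict bind: forces the pair, but not its components.
bindUsual :: State s a -> (a -> State s b) -> State s b
bindUsual m k = State $ \s -> case runState m s of
  (a, s') -> runState (k a) s'

-- The proposed bind: additionally forces the result value a.
bindBangA :: State s a -> (a -> State s b) -> State s b
bindBangA m k = State $ \s -> case runState m s of
  (!a, s') -> runState (k a) s'

-- Left identity demands: returnS undefined `bind` const b  ==  b.
-- Returns True if the bind yields 'b' instead of throwing.
lawHolds :: (State Int Int -> (Int -> State Int Char) -> State Int Char) -> IO Bool
lawHolds bind = do
  r <- try (evaluate (fst (runState (returnS undefined `bind` const (returnS 'b')) 0)))
  pure $ case r of
    Right c -> c == 'b'
    Left e  -> const False (e :: SomeException)
```

With these definitions lawHolds bindUsual returns True and lawHolds bindBangA returns False.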
Suppose you instead do this:
m >>= k = State $ \s ->
  case runState m s of
    (a, !s') -> runState (k a) s'
Now we face another identity law:
m >>= return = m
For example,
return a >>= return = return a
So if return a >>= return is strict in the state, then we must also have return a strict in the state! So we need to redefine return as well:
return a = State $ \ !s -> (a, s)
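Put together, the suggestion looks roughly like this sketch, again assuming a hand-rolled State type. Here return comes from pure, which carries the bang; countTo is a hypothetical example whose counter is forced at every bind, so it should not build a chain of (+ 1) thunks.

```haskell
{-# LANGUAGE BangPatterns #-}

newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = m >>= \a -> pure (f a)

instance Applicative (State s) where
  -- return defaults to pure, so this is the strict return described above.
  pure a = State $ \ !s -> (a, s)
  mf <*> ma = mf >>= \f -> ma >>= \a -> pure (f a)

instance Monad (State s) where
  -- Strict in the new state, lazy in the result value.
  m >>= k = State $ \s ->
    case runState m s of
      (a, !s') -> runState (k a) s'

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

-- Hypothetical example: increment the state n times.
countTo :: Int -> State Int ()
countTo 0 = pure ()
countTo n = modify (+ 1) >>= \_ -> countTo (n - 1)
```

Here snd (runState (countTo 100000) 0) is 100000, and return undefined >>= const (return 'b') still behaves like return 'b', as left identity requires, because only the state is forced.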
Note that you don't really need to do any of this; if you want, you can use the usual strict state monad, and write things like
!_ <- get
in the spots where you want to force the state. You could even write an action to do this:
forceState :: Monad m => StateT s m ()
forceState = get >>= \ !_ -> return ()
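A sketch of forceState against the strict StateT from the transformers package (Control.Monad.Trans.State.Strict); bump is a hypothetical example. For the common case of forcing right after an update, that module's modify' does the same job.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.Trans.State.Strict (State, StateT, execState, get, modify)

-- Force the current state to weak head normal form; otherwise a no-op.
forceState :: Monad m => StateT s m ()
forceState = get >>= \ !_ -> return ()

-- Hypothetical example: bump a counter n times, forcing it after each step
-- so that no chain of (+ 1) thunks builds up.
bump :: Int -> State Int ()
bump 0 = return ()
bump n = modify (+ 1) >> forceState >> bump (n - 1)
```

Then execState (bump 100000) 0 evaluates to 100000.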
Edit
Even this definition feels a little bit strange to me; I would expect the lambda to force the state, rather than the case. I'm not sure if not doing that leads to some kind of breakage, but it wouldn't surprise me if it did.

Related

Haskell instance of `bind` for a custom type

I'm trying to define the bind operator (>>=) for the custom type ST a.
I found this way to do it but I don't like that hardcoded 0.
Is there any way to implement it without having the hardcoded 0 and respecting the type of the function?
newtype ST a = S (Int -> (a, Int))
-- This may be useful to implement ">>=" (bind), but it is not mandatory to use it
runState :: ST a -> Int -> (a, Int)
runState (S s) = s
instance Monad ST where
  return :: a -> ST a
  return x = S (\n -> (x, n))
  (>>=) :: ST a -> (a -> ST b) -> ST b
  s >>= f = f (fst (runState s 0))
I often find it easier to follow such code through a certain kind of pseudocode rewrite. Starting with
instance Monad ST where
  return :: a -> ST a
  return x = S (\n -> (x, n))
we get
runState (return x) n = (x, n)
which expresses exactly the same thing. It is now a kind of definition through an interaction law that the code must follow. This allows me to ignore the "noise" of wrapping and unwrapping around the essential stuff.
Similarly, then, we have
(>>=) :: ST a -> (a -> ST b) -> ST b
s >>= f = -- f (fst (runState s 0))  -- nah, 0? what's that?
    --
    -- runState (s >>= f) n = runState (f a) i where
    --     (a, i) = runState s n
    --
    S $ \ n -> let (a, i) = runState s n in
               runState (f a) i
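The definition above, assembled into a compilable sketch. Current GHC requires Functor and Applicative instances for any Monad, so those are added; the instance type signatures from the question are left out (they need the InstanceSigs extension), and fresh is a hypothetical action that returns the counter and increments it.

```haskell
newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

instance Functor ST where
  fmap g st = S $ \n ->
    let (a, i) = runState st n
    in (g a, i)

instance Applicative ST where
  pure x = S (\n -> (x, n))
  sg <*> sa = S $ \n ->
    let (g, i)  = runState sg n
        (a, i') = runState sa i
    in (g a, i')

instance Monad ST where
  -- No hardcoded 0: the incoming state n is threaded through both stages.
  s >>= f = S $ \n ->
    let (a, i) = runState s n
    in runState (f a) i

-- Hypothetical example action: return the current counter and increment it.
fresh :: ST Int
fresh = S (\n -> (n, n + 1))
```

For instance, runState (fresh >>= \a -> fresh >>= \b -> return (a, b)) 10 evaluates to ((10, 11), 12).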
because now we have an Int in sight (i.e. in scope), n, which will be provided to us when the combined computation s >>= f is "run", that is, when it is passed to runState.
Of course nothing actually runs until called upon from main. But it can be a helpful metaphor to hold in mind.
The way we've defined it is both the easiest and the most general, which is usually the way to go. There are more ways to make the types fit though.
One is to use n twice, in the input to the second runState as well, but this will leave the i hanging unused.
Another way is to flip the time arrow around w.r.t. the state passing, with
S $ \ n -> let (a, i2) = runState s i
               (b, i ) = runState (f a) n
           in (b, i2)
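To see the reversed state flow concretely, here is a sketch of that variant as a plain function (bindRev is a hypothetical name). The let is recursive: i is produced by the second stage and consumed by the first, which only works because let patterns are lazy.

```haskell
newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

-- Values flow forward (s runs "first" and feeds a to f), but the state
-- flows backward: s receives the state that f's computation produced.
bindRev :: ST a -> (a -> ST b) -> ST b
bindRev s f = S $ \n ->
  let (a, i2) = runState s i
      (b, i)  = runState (f a) n
  in (b, i2)
```

For example, runState (bindRev (S (\i -> (i, i + 1))) (\a -> S (\n -> (a, n * 2)))) 5 gives (10, 11): the first stage saw 10, which is the state the second stage produced from 5.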
which is a bit weird to say the least. s still runs first (as expected for the s >>= f combination) to produce the value a from which f creates the second computation stage, but the state is being passed around in the opposite direction.
The most important thing to keep in mind is that your ST type is a wrapper around a function. What if you started your definition as (>>=) = \s -> \f -> S (\n -> ... )? It might be (ok, is) a bit silly to write separate lambdas for the s and f parameters there, but I did it to show that they're not really any different from the n parameter. You can use it in your definition of (>>=).

Understanding the state argument in the State Monad

I'm trying so hard to wrap my head around the State Monad, and I do not understand the following:
Given the implementation of return and (>>=), when you write State $ \s -> ..., where does s come from? When you start chaining >>= ... >>=, doesn't it mean that somewhere at the beginning of the chain you have to provide that initial parameter?
newtype State s a = State { runState :: s -> (a, s) }
instance Monad (State s) where
  return a = State $ \s -> (a, s)
  (>>=) m g = State $ \s -> let (a, s') = runState m s in
                            runState (g a) s'
In (>>=) you write State $ \s -> ... runState m s ..., and I do not understand when that initial (\s -> ...) lambda is called with a REAL argument.
Can someone explain, please?
Later Edit:
Can someone show me how the initial state would be set, let's say if it needs to get its value using getLine?
main :: IO ()
main = do
  argument <- getLine
  -- how do I set the initial state with argument?
  m >> f1 >> f2 >> f3
when you say State $ \s ->...., where does s come from?
It comes from the invocation: when you apply runState to the state-monadic value, you supply the initial state, which runs the combined computation the value describes:
st = do { x <- get ; return (x+1) }
x = runState st 0 -- x is (1,0)
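A self-contained version of this example, assuming a hand-rolled State type and get (both defined here rather than imported). st' is the same computation with the do block desugared by hand; in both cases the state 0 only enters the picture when runState is applied.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = m >>= \a -> return (f a)

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  mf <*> ma = mf >>= \f -> fmap f ma

instance Monad (State s) where
  m >>= g = State $ \s ->
    let (a, s') = runState m s
    in runState (g a) s'

-- Illustrative get: return the current state as the result, unchanged.
get :: State s s
get = State $ \s -> (s, s)

st :: State Int Int
st = do { x <- get ; return (x + 1) }

-- The same computation with the do block desugared by hand.
st' :: State Int Int
st' = get >>= \x -> return (x + 1)
```

Both runState st 0 and runState st' 0 evaluate to (1, 0).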
I also sense another possible misunderstanding on your part: you ask "when is that initial (\s -> ...) argument called?" There's no "initial" lambda: the lambdas are all nested inside one another!
do { a <- as; b <- bs; c <- foo b; return c }
translates as
as >>= (\a -> bs >>= (\b -> foo b >>= (\c -> return c)))
so there is no separate "initial" lambda; there is one combined, all-enclosing lambda that is called with the initial state!
That lambda then evaluates
let (a,s1) = runState as s0
and so on, where as is the first action of the do block.
The do block does not perform any stateful computation; it only assembles some smaller stateful computations into one bigger stateful computation. At the do level, the actual state does not exist.
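This also answers the "Later Edit" in the question: the initial state is just the argument passed to runState, so it can come from getLine. A sketch, assuming a hand-rolled State type and hypothetical definitions of m, f1, f2 and f3 (the question does not give them).

```haskell
import Text.Read (readMaybe)

newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f act = act >>= \a -> return (f a)

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  af <*> aa = af >>= \f -> fmap f aa

instance Monad (State s) where
  act >>= k = State $ \s ->
    let (a, s') = runState act s
    in runState (k a) s'

-- Hypothetical stand-ins for m, f1, f2, f3 from the question.
m, f1, f2, f3 :: State Int ()
m  = State $ \s -> ((), s + 1)
f1 = State $ \s -> ((), s * 2)
f2 = State $ \s -> ((), s - 3)
f3 = State $ \s -> ((), s * 10)

-- The chain is an ordinary value: no state exists until runState is applied.
chain :: State Int ()
chain = m >> f1 >> f2 >> f3

-- In a program you would write main = runFromStdin.
runFromStdin :: IO ()
runFromStdin = do
  argument <- getLine
  case readMaybe argument of
    Just s0 -> print (runState chain s0)  -- s0 becomes the initial state
    Nothing -> putStrLn "please enter an integer"
```

Entering 5 prints ((),90) with these stand-in actions.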
It would be simpler and maybe even more accurate if the monad were called "a stateful computation", or "a function that takes a state of type s and returns another state of the same type alongside its actual result". Then you could imagine >>= as "combines two functions of the aforementioned type into one, such that the state returned by the first one is passed as a parameter to the second one".
State is just a wrapper around functions of type s -> (a, s). runState doesn't actually "run" anything; it just gives back the function wrapped by the State constructor. You can, however, compare runState to the ($) operator for functions.
($) f x = f x
runState (State f) s = f s
That makes (=<<) = flip (>>=) similar to (<<<) = (.), except that rather than taking two functions and returning a third function, it takes a function (that returns a State) and a State and produces a new State.
However, we'll make a direct comparison of (>>=) to (>>>) = flip (.) so that the types align better. (Similarly, you could compare (.) to (=<<).)
-- f :: t -> a
-- g :: a -> b
f >>> g = \t -> let a = ($) f t
                in ($) g a
-- m :: State s a
-- g :: a -> State s b
m >>= g = State $ \s -> let (a, s') = runState m s
                        in runState (g a) s'
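The side-by-side comparison compiles as ordinary Haskell. A sketch, assuming a hand-rolled State type; (>>>) here is a local definition (Control.Arrow exports a more general one), and bindS is a hypothetical name used instead of an instance so the two definitions can sit next to each other.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Forward function composition: run f, then feed its output to g.
(>>>) :: (t -> a) -> (a -> b) -> (t -> b)
f >>> g = \t ->
  let a = ($) f t
  in ($) g a

-- Same shape: runState plays the role of ($), and a state value is
-- threaded alongside the intermediate result.
bindS :: State s a -> (a -> State s b) -> State s b
m `bindS` g = State $ \s ->
  let (a, s') = runState m s
  in runState (g a) s'
```

For example, (length >>> (* 2)) "abc" is 6, and running State (\s -> (s, s + 1)) `bindS` (\a -> State (\s -> (a * 100, s))) from state 7 gives (700, 8).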

Haskell State Monad - What would the input to lambda \s -> ... be?

In the tutorial Learn You a Haskell - chapter 'for-a-few-monads-more', section 'The State monad', it lists the following to define the State Monad:
newtype State s a = State { runState :: s -> (a,s) }
instance Monad (State s) where
  return x = State $ \s -> (x,s)
  (State h) >>= f = State $ \s ->
    let (a, newState) = h s
        (State g) = f a
    in g newState
Just need an answer to a simple question: what would the input to \s be? Since State h wraps a function that takes a state and outputs a tuple (result, newState), does that imply the input to \s would just be that function? Examples welcome.
You can think of a value of State s a as being a computation which depends on some state parameter which is provided when the computation is run. You run it by unwrapping the contained function and calling it, e.g.
runState (return 1) "state"
=> (1, "state")
You can picture return x as meaning "Give me a state, and I'll give you back that state and x". Then you can think of x >>= f as "Give me a state, and I'll give that to x; once it returns a state and a value, I'll give those to f and pass what f gives me on to you."
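Both readings can be checked directly. A sketch with a hand-rolled State type that follows the LYAH instance, plus the Functor and Applicative instances current GHC requires; measure and shout are hypothetical stand-ins for x and f.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State h) = State $ \s -> let (a, s') = h s in (f a, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  State hf <*> State hx = State $ \s ->
    let (f, s1) = hf s
        (x, s2) = hx s1
    in (f x, s2)

instance Monad (State s) where
  (State h) >>= f = State $ \s ->
    let (a, newState) = h s
        (State g) = f a
    in g newState

-- Hypothetical x: report the length of the state string, leaving it unchanged.
measure :: State String Int
measure = State $ \s -> (length s, s)

-- Hypothetical f: use the value it is given, and append to the state.
shout :: Int -> State String Bool
shout n = State $ \s -> (even n, s ++ "!")
```

Here runState (return 1) "state" gives (1, "state"), and runState (measure >>= shout) "abcd" gives (True, "abcd!"): measure passes 4 and the unchanged state on, and shout turns them into the final pair.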
Here's an analogy with function composition:
f, g, h :: a -> a
j = f . g . h :: a -> a
j is a function that takes an a, and returns an a. Along the way, that value is first given to h, whose output goes to g, whose output goes to f, whose output is returned.
Look at "composition" of functions that return State values.
f', g', h' :: a -> State s a
j' a = return a >>= h' >>= g' >>= f'
Think of State as being a way of "hiding" a function argument. You could compose them manually like this:
f'', g'', h'' :: a -> s -> (a, s)
-- We could make this point-free by leaving h curried, but this
-- is nicely consistent.
j'' a s = (uncurry f'') . (uncurry g'') . (uncurry h'') $ (a, s)
but the State monad effectively does that for you with its implementation of >>=.
Note that j' only takes an initial value, while j'' takes an initial value and an initial state. That's because the function that would take that state for j' is still wrapped up in a State value; we use runState to retrieve that function so that an initial state can be supplied to the first function in the stack.
(runState (j' a0)) s0 -- outer parentheses optional, shown for emphasis; a0 is the initial value
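The two formulations can be compared side by side. A sketch, assuming a hand-rolled State type; the concrete h', g', f' (working on Int values with an Int state that counts the steps) are hypothetical, and h'', g'', f'' are obtained by unwrapping them.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = m >>= \a -> return (f a)

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  mf <*> ma = mf >>= \f -> fmap f ma

instance Monad (State s) where
  (State h) >>= f = State $ \s ->
    let (a, newState) = h s
        (State g) = f a
    in g newState

-- Hypothetical steps: each transforms the value and counts itself in the state.
h', g', f' :: Int -> State Int Int
h' a = State $ \s -> (a + 1, s + 1)
g' a = State $ \s -> (a * 2, s + 1)
f' a = State $ \s -> (a - 3, s + 1)

j' :: Int -> State Int Int
j' a = return a >>= h' >>= g' >>= f'

-- The same steps with the state argument made explicit.
h'', g'', f'' :: Int -> Int -> (Int, Int)
h'' a s = runState (h' a) s
g'' a s = runState (g' a) s
f'' a s = runState (f' a) s

j'' :: Int -> Int -> (Int, Int)
j'' a s = (uncurry f'') . (uncurry g'') . (uncurry h'') $ (a, s)
```

With these steps, runState (j' 5) 0 and j'' 5 0 both evaluate to (9, 3).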

Haskell: Join on State Monad

How to formally calculate/interpret the following expression?
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0]
I understand the informal explanation, which says: first run the outer stateful computation and then the resulting one.
Well, that's quite strange to me, since if I follow the join and >>= definitions, it looks like I have to start from the inner monadic value (push 10) as the parameter of id, and then do... I'm not sure what... in order to get what is supposedly the result:
((),[10,1,2,0,0,0])
However, how do I explain it using the formal definitions:
instance Monad (State s) where
  return x = State $ \s -> (x,s)
  (State h) >>= f = State $ \s ->
    let (a, newState) = h s
        (State g) = f a
    in g newState
and
join :: Monad m => m (m a) -> m a
join n = n >>= id
Also, the definition of the State Monad's bind (>>=) is quite hard to grasp as having some "intuitive"/visual meaning (as opposed to just a formal definition that would satisfy the Monad laws). Does it have a less formal and more intuitive meaning?
The classic definition of State is pretty simple.
newtype State s a = State {runState :: s -> (a,s) }
A State s a is a "computation" (actually just a function) that takes something of type s (the initial state) and produces something of type a (the result) and something of type s (the final state).
The definition you give in your question for >>= makes State s a a "lazy state transformer". This is useful for some things, but a little harder to understand and less well-behaved than the strict version, which goes like this:
m >>= f = State $ \s ->
  case runState m s of
    (x, s') -> runState (f x) s'
I've removed the laziness and also taken the opportunity to use a record selector rather than pattern matching on State.
What's this say? Given an initial state, I runState m s to get a result x and a new state s'. I apply f to x to get a state transformer, and then run that with initial state s'.
The lazy version just uses lazy pattern matching on the tuple. This means that the function f can try to produce a state transformer without inspecting its argument, and that transformer can try to run without looking at the initial state. You can use this laziness in some cases to tie recursive knots, implement funny functions like mapAccumR, and use state in lazy incremental stream processing, but most of the time you don't really want/need that.
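One way to see the difference in practice is lazy incremental stream processing. A sketch, assuming a hand-rolled State type with both binds written as plain functions (bindLazy, bindStrict and runningSums are hypothetical names; RankNTypes lets runningSums take either bind as an argument).

```haskell
{-# LANGUAGE RankNTypes #-}

newtype State s a = State { runState :: s -> (a, s) }

pureS :: a -> State s a
pureS a = State $ \s -> (a, s)

-- The bind from the question: let matches the pair lazily.
bindLazy :: State s a -> (a -> State s b) -> State s b
bindLazy m f = State $ \s ->
  let (x, s') = runState m s
  in runState (f x) s'

-- The strict version: case forces the pair before continuing.
bindStrict :: State s a -> (a -> State s b) -> State s b
bindStrict m f = State $ \s ->
  case runState m s of
    (x, s') -> runState (f x) s'

type Bind = forall a b. State Int a -> (a -> State Int b) -> State Int b

-- Hypothetical example: running sums of a list, parameterised by the bind.
runningSums :: Bind -> [Int] -> State Int [Int]
runningSums _ [] = pureS []
runningSums bind (x:xs) =
  State (\s -> (s + x, s + x)) `bind` \y ->
  runningSums bind xs `bind` \ys ->
  pureS (y : ys)
```

take 5 (fst (runState (runningSums bindLazy [1 ..]) 0)) yields [1, 3, 6, 10, 15] even though the input is infinite. With bindStrict in its place, the same expression never returns, because each bind insists on finishing the rest of the list first; on finite lists the two agree.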
Lee explains pretty well what join does, I think.
If you specialise the type of join for State s you get:
join :: State s (State s a) -> State s a
so given a stateful computation which returns a result which is another stateful computation, join combines them into a single one.
The definition of push is not given in your question but I assume it looks like:
push :: a -> State [a] ()
push x = modify (x:)
along with some State type like
data State s a = State (s -> (a, s))
A value of State s a is a function which, given a value for the current state of type s, returns a pair containing a result of type a and a new state value. Therefore
State $ \s -> (push 10,1:2:s)
has type State [Int] (State [Int] ()) (or some numeric type other than Int). The outer State function returns another State computation as its result, and updates the state to have the values 1 and 2 pushed onto it.
An implementation of join for this State type would look like:
join :: State s (State s a) -> State s a
join outer = State $ \s ->
  let (inner, s') = runState outer s
  in runState inner s'
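A runnable version of this walkthrough, assuming a hand-rolled State type and the assumed push (written here without modify so nothing needs importing). joinViaBind is the question's n >>= id, included to show that both definitions give the same result.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  mf <*> mx = State $ \s ->
    let (f, s1) = runState mf s
        (x, s2) = runState mx s1
    in (f x, s2)

instance Monad (State s) where
  m >>= k = State $ \s ->
    let (a, s') = runState m s
    in runState (k a) s'

-- join written out directly: run outer, then run the computation it returned.
join :: State s (State s a) -> State s a
join outer = State $ \s ->
  let (inner, s') = runState outer s
  in runState inner s'

-- The general definition, specialised to State.
joinViaBind :: State s (State s a) -> State s a
joinViaBind n = n >>= id

-- Assumed definition of push (the question does not give one).
push :: a -> State [a] ()
push x = State $ \s -> ((), x : s)
```

Running either version on State $ \s -> (push 10, 1:2:s) from [0,0,0] gives ((),[10,1,2,0,0,0]).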
so it constructs a new stateful computation which first runs the outer computation to return a pair containing the inner computation and the new state. The inner computation is then run with the intermediate state.
If you plug your example into this definition then
outer = (State $ \s -> (push 10,1:2:s))
s = [0,0,0]
inner = push 10
s' = [1,2,0,0,0]
and the result is therefore the result of runState (push 10) [1,2,0,0,0] which is ((),[10,1,2,0,0,0])
You mentioned following the definitions for join and >>=, so, let's try that.
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0] = ?
The definitions are, again
instance Monad (State s) where
  -- return :: a -> State s a
  return x = State $ \s -> (x,s)
  -- so for x :: a, State $ \s -> (x,s) :: State s a
  (State h) >>= f = State $ \s ->
    let (a, newState) = h s
        (State g) = f a
    in g newState
join m = m >>= id
and runState :: State s a -> s -> (a, s) undoes that wrapping, i.e.
runState (State g) s = g s. So, following the definitions, we have
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0]
= runState (State g) [0,0,0]
  where (State g) = join (State $ \s -> (push 10,1:2:s))
          = (State $ \s -> (push 10,1:2:s)) >>= id
          -- this is (State h) >>= f with h s = (push 10,1:2:s) and f = id
          = State $ \s -> let (a, newState) = h s
                              (State g) = id a
                              h s = (push 10,1:2:s)
                          in g newState
          = State $ \s -> let (a, newState) = (push 10,1:2:s)
                              (State g) = a
                          in g newState
          = State $ \s -> let (State g) = push 10
                          in g (1:2:s)
Now, push 10 :: State s a is supposed to match State g, where g :: s -> (a, s); most probably it's defined as push 10 = State $ \s -> ((),(10:) s), so we have
          = State $ \s -> let (State g) = State $ \s -> ((),(10:) s)
                          in g (1:2:s)
          = State $ \s -> let g s = ((),(10:) s)
                          in g (1:2:s)
          = State $ \s -> ((),(10:) (1:2:s))
= runState (State $ \s -> ((),(10:) (1:2:s))) [0,0,0]
= (\s -> ((),(10:) (1:2:s))) [0,0,0]
= ((), 10:1:2:[0,0,0])
So you see that push 10 is first produced as a result value (with (a, newState) = (push 10,1:2:s)); then it is treated as a computation description of type State s a, so it is run last (not first, as you thought).
As Lee describes, join :: State s (State s a) -> State s a. The meaning of this type is that a computation of type State s (State s a) produces a State s a as its result value (here, push 10), and we can run that inner computation only after we get hold of it.

Understanding State Monad

Looking at Learn You a Haskell's definition of the State Monad:
instance Monad (State s) where
return x = State $ \s -> (x,s)
(State h) >>= f = State $ \s -> let (a, newState) = h s
(State g) = f a
in g newState
I don't understand the types of h s and g newState in the lower right-hand side.
Can you please explain their types and what's going on?
State s a is a name for a function type, the "state transformer function"
s -> (a, s)
In other words, it takes an input state s and modifies that state while also returning a result, a. This forms a really general framework of "pure state". If our state is an integer, we can write a function which updates that integer and returns the new value; this is like a unique number source.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')
Here, a and s end up being the same type.
Now this is all fine and good, except that we're in trouble if we'd like to get two fresh numbers. For that we must somehow run upd twice.
The final result is going to be another state transformer function, so we're looking for a "state transformer transformer". I'll call it compose:
compose :: (s -> (a, s)) -- the initial state transformer
-> (a -> (s -> (b, s))) -- a new state transformer, built using the "result"
-- of the previous one
-> (s -> (b, s)) -- the result state transformer
This is a little hairy looking, but honestly it's fairly easy to write this function. The types guide you to the answer:
compose f f' = \s -> let (a, s') = f s
                         (b, s'') = f' a s'
                     in (b, s'')
You'll notice that the s-typed variables, [s, s', s''] "flow downward" indicating that state moves from the first computation through the second leading to the result.
We can use compose to build a function which gets two unique numbers using upd
twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))
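These pieces compile as they stand; here they are collected into one sketch, with the let bindings laid out so that GHC's layout rule accepts them.

```haskell
-- A unique-number source: bump the counter and return the new value.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

-- Chain two state transformers, threading the state from the first
-- into the second.
compose :: (s -> (a, s))
        -> (a -> (s -> (b, s)))
        -> (s -> (b, s))
compose f f' = \s ->
  let (a, s')  = f s
      (b, s'') = f' a s'
  in (b, s'')

twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))
```

Starting from 0, twoUnique 0 evaluates to ((1, 2), 2): two fresh numbers and the final counter.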
These are the basics of State. The only difference is that we recognize there's a common pattern going on inside of the compose function and we extract it. That pattern looks like
(>>=) :: State s a -> (a -> State s b ) -> State s b
(>>=) :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
It's implemented the same way, too. We just need to "wrap" and "unwrap" the State bit: that's the purpose of State and runState
State :: (s -> (a, s)) -> State s a
runState :: State s a -> (s -> (a, s))
Now we can take compose and compare it to (>>=)
compose f f' = \s -> let (a, s') = f s
                         (b, s'') = f' a s'
                     in (b, s'')
(>>=) (State f) f' = State $ \s -> let (a, s') = f s
                                       (b, s'') = runState (f' a) s'
                                   in (b, s'')
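The comparison can be made literal. A sketch (bindS and twoUniqueS are hypothetical names) that defines bind by wrapping compose, then rebuilds twoUnique from State values:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

compose :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
compose f f' = \s ->
  let (a, s')  = f s
      (b, s'') = f' a s'
  in (b, s'')

-- Bind is compose plus wrapping (State) and unwrapping (runState).
bindS :: State s a -> (a -> State s b) -> State s b
bindS (State f) f' = State (compose f (runState . f'))

-- twoUnique again, this time built from State values.
twoUniqueS :: State Int (Int, Int)
twoUniqueS = State upd `bindS` \a ->
             State upd `bindS` \a' ->
             State (\s -> ((a, a'), s))
```

runState twoUniqueS 0 gives ((1, 2), 2), the same answer as the compose version.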
The State Monad certainly is confusing the first time you see it. The first thing that's important to understand is its data declaration, which is
newtype State s a = State { runState :: s -> (a,s) }
so a State contains a function with the type s -> (a,s). We can think of this as a function acting on some sort of generator and returning a tuple of a value and a new generator. This is how random numbers work in Haskell, for example: s is the generator while a is the result of the function that takes a generator as input and outputs a random number a (say, of type Int, but it could just as easily be any other type).
Now let's talk about the instance declaration. Recall the type of (>>=) is
Monad m => m a -> (a -> m b) -> m b
In particular, we note that f should have the type a -> m b. In this case, m is State s, so the type of f should be a -> State s b. So now we can break down the instance declaration
(State h) >>= f = State $ \s ->
  let (a, newState) = h s
      (State g) = f a
  in g newState
Since f has the type a -> State s b, the type of State g must be State s b (i.e. g :: s -> (b,s)), and since h has the type s -> (a,s), we must have newState :: s. Thus the result of the bind expression is g newState, which is of type (b, s).
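One way to check these types is to let GHC do it. A sketch of the LYAH bind as a standalone function (bindAnnotated is a hypothetical name) with each intermediate value given the type derived above; ScopedTypeVariables lets the local signatures refer to the outer s, a and b.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

newtype State s a = State { runState :: s -> (a, s) }

bindAnnotated :: forall s a b. State s a -> (a -> State s b) -> State s b
bindAnnotated (State h) f = State $ \s ->
  let hs :: (a, s)
      hs = h s                 -- h s :: (a, s)
      (a, newState) = hs       -- the value a :: a, newState :: s
      g :: s -> (b, s)
      State g = f a            -- f a :: State s b
      result :: (b, s)
      result = g newState      -- g newState :: (b, s)
  in result
```

If any annotation were wrong, the module would fail to type-check. As a usage example, running bindAnnotated (State (\s -> (s, s + 1))) (\a -> State (\s -> (show a, s * 2))) from state 3 gives ("3", 8).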
For further reading, here is a great article that helped me to understand the State Monad when I first came across it.
From the definition of the State monad at LYAH:
newtype State s a = State { runState :: s -> (a,s) }
This means the argument to the State data constructor is a function which takes a state and produces an a and a new state. Thus h in the example above is a function, and h s computes a and newState.
From Hoogle we see the definition of (>>=) is
(>>=) :: Monad m => m a -> (a -> m b) -> m b
which means f is a function from a to State s b. Thus it makes sense to give f the argument a, and the result is a State. Just like h, g is the argument to a State constructor: a function which takes a state (in this case newState) and returns a pair (b, newState2).
It might be more instructive to ask what (>>=) actually does: with its arguments flipped, it turns a function a -> State s b into a function State s a -> State s b. A State is just a placeholder for a value depending on the current state, which is why the argument to the constructor depends on the state. Thus, given a State "value", we first apply h to the incoming state s to get the corresponding value and a new state. Now we pass that value to f (note that the types match up) and get a new State, i.e. a new function from a state to a value. Finally, we apply that function to newState to thread the state to the next part of the computation.
