State Monad and 'put' function in Haskell

The Documentation about the State Monad says:
put :: s -> m ()
Replace the state inside the monad.
I cannot understand that. Does it mean that the function replaces the state inside the monad? And the second issue: why is the returned value m () and not m s?

The easiest way to understand the state monad, I think, is just to write your own and play around with it a bit. Study this code, play with other people's examples, and come back and review it from time to time until you're able to write it from memory:
-- | 'State' is just a newtype wrapper around the type @s -> (a, s)@.
-- These are functions which are fed a state value (type @s@) as input,
-- and produce a pair of an @a@ (the *result* of the state action)
-- and an @s@ (the *new state* after the action).
--
-- The 'State' type is fundamentally a shortcut for chaining functions
-- of types like that.
newtype State s a = State { runState :: s -> (a, s) }
instance Functor (State s) where
  fmap f (State g) = State $ \s0 ->
    let (a, s1) = g s0
    in (f a, s1)
instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State ff <*> State fa = State $ \s0 ->
    let (f, s1) = ff s0
        (a, s2) = fa s1
    in (f a, s2)
instance Monad (State s) where
  return = pure
  State fa >>= f = State $ \s0 ->
    let (a, s1) = fa s0
        (b, s2) = runState (f a) s1
    in (b, s2)
-- | 'get' is just a wrapper around a function that takes the
-- incoming @s@ value and exposes it in the position where @a@
-- normally goes.
get :: State s s
get = State $ \s -> (s, s)
-- | 'put' is a wrapper around a function that discards
-- the incoming @s@ value and replaces it with another.
put :: s -> State s ()
put s = State $ \_ -> ((), s)
This is written directly in terms of a State type without using the MonadState class, which is a bit simpler to understand at first. As an exercise, once you feel comfortable with this, you can try writing it with the MonadState class.
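To play with the definitions above, here is a minimal self-contained sketch that uses only base. The `tick` action is a hypothetical example of mine, not something from the library documentation; it returns the current counter and then stores the incremented value with `put`.

```haskell
-- A hand-rolled State, mirroring the definitions above (base only, no mtl).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s -> let (a, s') = g s in (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State ff <*> State fa = State $ \s ->
    let (f, s')  = ff s
        (a, s'') = fa s'
    in (f a, s'')

instance Monad (State s) where
  State fa >>= k = State $ \s ->
    let (a, s') = fa s
    in runState (k a) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical example action: return the current counter, then bump it.
tick :: State Int Int
tick = do
  n <- get
  put (n + 1)
  return n

main :: IO ()
main = do
  print (runState tick 5)                    -- (5,6)
  print (runState (tick >> tick >> tick) 0)  -- (2,3)
  print (runState (put 42) 0)                -- ((),42)
```

The last line shows what "replace the state inside the monad" means: whatever state comes in (here 0) is thrown away, and 42 comes out as the new state.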
And the second issue: why is the returned value m () and not m s?
It's mostly an arbitrary design choice, as far as I can tell. If I were designing the State type I might have written get and put like this, which is more similar to your expectation:
-- | Modify the incoming state by applying the given function to it.
-- Produces the previous, now discarded state as a result, which is
-- often useful.
modify :: (s -> s) -> State s s
modify f = State $ \s -> (s, f s)
-- Now 'get' and 'put' can be written in terms of 'modify':
get :: State s s
get = modify (\s -> s)
-- | This version of 'put' returns the original, discarded state,
-- which again is often useful.
put :: s -> State s s
put s = modify (\_ -> s)
If you have the standard 'get' and 'put' you can use that to write my modified 'put' as well:
-- | 'get' the incoming state, 'put' a new one in, and 'return' the old one.
replace :: s -> State s s
replace s1 = do
  s0 <- get
  put s1
  return s0
So it doesn't make a big difference whether put produces () or s, anyway.
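As a quick check of that claim, the sketch below defines both variants next to each other and runs them on the same input. The names `modifyOld` and `putOld` are mine, chosen so they do not clash with the standard `modify` and `put`; the `State` type is a hand-rolled one using only base.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s -> let (a, s') = g s in (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State ff <*> State fa = State $ \s ->
    let (f, s')  = ff s
        (a, s'') = fa s'
    in (f a, s'')

instance Monad (State s) where
  State fa >>= k = State $ \s ->
    let (a, s') = fa s
    in runState (k a) s'

-- Standard-style primitives: 'put' produces ().
get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Alternative primitive (hypothetical name): produces the old state.
modifyOld :: (s -> s) -> State s s
modifyOld f = State $ \s -> (s, f s)

putOld :: s -> State s s
putOld s = modifyOld (\_ -> s)

-- The same behaviour recovered from the standard 'get' and 'put'.
replace :: s -> State s s
replace s1 = do
  s0 <- get
  put s1
  return s0

main :: IO ()
main = do
  print (runState (putOld "new") "old")   -- ("old","new")
  print (runState (replace "new") "old")  -- ("old","new")
```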

Related

Understanding the state argument in the State Monad

I'm trying so hard to wrap my head around the State Monad, and I do not understand the following:
Given the implementation of return and (>>=), when you say State $ \s ->...., where does s come from? I mean, when you start performing >>= ... >>=, doesn't it mean that somewhere in your beginning of the chain you somehow have to provide for that initial parameter?
newtype State s a = State { runState :: s -> (a, s) }
instance Monad (State s) where
  return a = State $ \s -> (a, s)
  (>>=) m g = State $ \s -> let (a, s') = runState m s in
      runState (g a) s'
In (>>=) you say State $ \s -> runState m s, and I do not get when is that initial (\s -> ...) argument (with a REAL argument) called?
Can someone explain, please?
Later Edit:
Can someone show me how would the initial state be set, let's say if it needs to get a value using getLine?
main :: IO ()
main = do
  argument <- getLine
  -- how do I set the initial state with argument?
  m >> f1 >> f2 >> f3
when you say State $ \s ->...., where does s come from ?
It comes from the invocation: runState supplies the initial state value to the state-monadic value, which runs the combined computation it describes:
st = do { x <- get ; return (x+1) }
x = runState st 0 -- x is (1,0)
I also sense another possible misunderstanding on your part: you write: "when is that initial (\s -> ...) argument called?" There's no "initial" lambda: the lambdas are all nested inside!
do { a <- as; b <- bs; c <- foo b; return c }
translates as
as >>= (\a -> bs >>= (\b -> foo b >>= (\c -> return c)))
so it's not "initial", that's one combined all-enclosing lambda that is called with the initial state!
And then it will call
let (a,s1) = runState as s0
etc. with that "initial" as in the do block.
The do block does not perform any stateful computation; it only assembles some smaller stateful computations into one bigger stateful computation. At the do level, the actual state does not exist.
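For the "Later Edit" part of the question, a rough sketch of how the initial state could be supplied from input follows. The names `bump`, `pipeline` and `runFromInput` are hypothetical, and the input string is hard-coded so the example runs without a terminal; in an interactive program it would come from getLine.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s -> let (a, s') = g s in (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State ff <*> State fa = State $ \s ->
    let (f, s')  = ff s
        (a, s'') = fa s'
    in (f a, s'')

instance Monad (State s) where
  State fa >>= k = State $ \s ->
    let (a, s') = fa s
    in runState (k a) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical step: add one to the state.
bump :: State Int ()
bump = get >>= \n -> put (n + 1)

-- Assembling the chain touches no state at all; it only builds a function.
pipeline :: State Int Int
pipeline = bump >> bump >> bump >> get

-- The initial state enters here, and only here.
runFromInput :: String -> (Int, Int)
runFromInput line = runState pipeline (read line)

main :: IO ()
main = do
  -- In an interactive program: line <- getLine
  let line = "10"
  print (runFromInput line)  -- (13,13)
```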
It would be simpler and maybe even more accurate if the monad was called "a stateful computation". Or "a function that takes state of type S and returns another state of the same type alongside its actual result". Then you could imagine >>= as "combines two functions of the aforementioned type into one, such that the state returned by the first one is passed as a parameter to the second one".
State is just a wrapper around functions of type s -> (a, s). runState doesn't actually "run" anything; it just gives back the function wrapped by the State constructor. You can, however, compare runState to the ($) operator for functions.
($) f x = f x
runState (State f) s = f s
That makes (=<<) = flip (>>=) similar to (<<<) = (.); just rather than taking two functions and returning a third function, it takes a function (that returns a State) and a State and produces a second State.
However, we'll make a direct comparison of (>>=) to (>>>) = flip (.) so that the types align better. (Similarly, you could compare (.) to (=<<).)
-- f :: t -> a
-- g :: a -> b
f >>> g = \t -> let a = ($) f t
in ($) g a
-- m :: State s a
-- g :: a -> State s b
m >>= g = State $ \s -> let (a, s') = runState m s
in runState (g a) s'

is this piece of haskell code correct, if so why?

The Haskell wiki (here: https://wiki.haskell.org/State_Monad) says the State monad's bind operator is defined like this:
(>>=) :: State s a -> (a -> State s b) -> State s b
(act1 >>= fact2) s = runState act2 is
  where (iv,is) = runState act1 s
        act2 = fact2 iv
However, it seems incorrect to me, as the result of the bind operator is a function wrapped in a constructor and thus cannot be applied (I'm talking about this pattern: (act1 >>= fact2) s).
In short: A State object itself does not encapsulate the state, it encapsulates the change of a state.
Indeed, the State type is defined as:
newtype State s a = State { runState :: s -> (a, s) }
where runState is thus a function that takes a state s, and returns a result a and a new state.
The bind operator (>>=) :: State s a -> (a -> State s b) -> State s b basically "chains" state changes together. It thus takes one state changing function f1 :: s -> (a, s), and a function f2 :: a -> State s b, and thus creates a function g :: s -> (b, s) so to speak that is encapsulated in a State constructor. The second function f2 thus takes an a and returns such state changing function as well.
So the bind operator can be defined as:
(State f1) >>= f2 = State $ \i -> let (y, s) = f1 i in runState (f2 y) s
Here we have i as initial state, and we thus will first "chain" i through the f1 state changer. This returns then a 2-tuple: y is the "result" of that call, and s is the new state, we will then pass the result and the new state to f2. Note that here we do not make state changes at all, we only construct a State object that can do that. We thus postpone the real chaining.
If State is defined as above, however, then the piece of code in the question does not match that definition; as @HTWN says, it assumes State is defined as:
type State s a = s -> (a, s)
In that case, it is correct, given that runState is then the id function, since then:
(>>=) :: State s a -> (a -> State s b) -> State s b
(>>=) act1 fact2 = f
  where f s = act2 is
          where (iv,is) = act1 s
                act2 = fact2 iv
In order to make it compatible with our State type, we thus add some logic to unwrap and wrap it in the State data constructor:
(>>=) :: State s a -> (a -> State s b) -> State s b
(>>=) act1 fact2 = State f
  where f s = runState act2 is
          where (iv,is) = runState act1 s
                act2 = fact2 iv
Then it is indeed correct. The main error in the snippet is not wrapping the result in a State data constructor.
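To see that the applied pattern really does typecheck when State is a plain type synonym, here is a small sketch. The names `bindS`, `getS` and `putS` are mine; a type synonym for a function cannot be given its own Monad instance, so the bind is written as an ordinary function rather than as (>>=).

```haskell
-- State as a plain type synonym: there is no constructor to wrap or unwrap.
type State s a = s -> (a, s)

-- Same shape as the applied-pattern definition, under a hypothetical name.
bindS :: State s a -> (a -> State s b) -> State s b
(act1 `bindS` fact2) s = act2 is
  where
    (iv, is) = act1 s
    act2     = fact2 iv

getS :: State s s
getS s = (s, s)

putS :: s -> State s ()
putS new _ = ((), new)

main :: IO ()
main = print ((getS `bindS` \n -> putS (n * 2)) (21 :: Int))  -- ((),42)
```

Because `State s b` is literally `s -> (b, s)` here, the left-hand side `(act1 `bindS` fact2) s` is just a three-argument function definition.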

Haskell: Join on State Monad

How to formally calculate/interpret the following expression?
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0]
I understand the informal explanation, which says: first run the outer stateful computation and then the resulting one.
Well, that's quite strange to me since if I follow the join and >>= definitions, it looks to me like I have to start from the internal monad (push 10) as the parameter of the id, and then do... hmmmm... well... I'm not sure what.... in order to get what is supposedly the result:
((),[10,1,2,0,0,0])
However how to explain it by the formal definitions:
instance Monad (State s) where
  return x = State $ \s -> (x,s)
  (State h) >>= f = State $ \s -> let (a, newState) = h s
                                      (State g) = f a
                                  in g newState
and
join :: Monad m => m (m a) -> m a
join n = n >>= id
Also, the definition of the State Monad's bind (>>=) is quite hard to grasp as having some "intuitive"/visual meaning (as opposed to just a formal definition that would satisfy the Monad laws). Does it have a less formal and more intuitive meaning?
The classic definition of State is pretty simple.
newtype State s a = State {runState :: s -> (a,s) }
A State s a is a "computation" (actually just a function) that takes something of type s (the initial state) and produces something of type a (the result) and something of type s (the final state).
The definition you give in your question for >>= makes State s a a "lazy state transformer". This is useful for some things, but a little harder to understand and less well-behaved than the strict version, which goes like this:
m >>= f = State $ \s ->
  case runState m s of
    (x, s') -> runState (f x) s'
I've removed the laziness and also taken the opportunity to use a record selector rather than pattern matching on State.
What does this say? Given an initial state s, I evaluate runState m s to get a result x and a new state s'. I apply f to x to get a state transformer, and then run that with initial state s'.
The lazy version just uses lazy pattern matching on the tuple. This means that the function f can try to produce a state transformer without inspecting its argument, and that transformer can try to run without looking at the initial state. You can use this laziness in some cases to tie recursive knots, implement funny functions like mapAccumR, and use state in lazy incremental stream processing, but most of the time you don't really want/need that.
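The difference can be observed directly. The sketch below works on bare functions of type `s -> (a, s)` instead of the newtype, to keep it short; `bindLazy`, `bindStrict` and `constStep` are hypothetical names. The first step is `undefined`, so any bind that inspects its result will fail.

```haskell
import Control.Exception (SomeException, evaluate, try)

type St s a = s -> (a, s)

-- Lazy tuple match via let, analogous to the definition in the question.
bindLazy :: St s a -> (a -> St s b) -> St s b
bindLazy m k = \s -> let (a, s') = m s in k a s'

-- Strict tuple match via case, analogous to the version above.
bindStrict :: St s a -> (a -> St s b) -> St s b
bindStrict m k = \s -> case m s of (a, s') -> k a s'

-- A second step that ignores both the previous result and the incoming state.
constStep :: a -> St Int Int
constStep _ = \_ -> (1, 0)

main :: IO ()
main = do
  -- The lazy bind never forces the first step, so 'undefined' is harmless.
  print (bindLazy undefined constStep 5)  -- (1,0)
  -- The strict bind must evaluate the first step to match on its tuple.
  r <- try (evaluate (bindStrict undefined constStep 5))
         :: IO (Either SomeException (Int, Int))
  case r of
    Left _  -> putStrLn "strict bind forced the first step"
    Right v -> print v
```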
Lee explains pretty well what join does, I think.
If you specialise the type of join for State s you get:
join :: State s (State s a) -> State s a
so given a stateful computation which returns a result which is another stateful computation, join combines them into a single one.
The definition of push is not given in your question but I assume it looks like:
push :: a -> State [a] ()
push x = modify (x:)
along with some State type like
data State s a = State (s -> (a, s))
A value of State s a is a function which, given a value for the current state of type s returns a pair containing a result of type a and a new state value. Therefore
State $ \s -> (push 10,1:2:s)
has type State [Int] (State [Int] ()) (or some other numeric type instead of Int). The outer State function returns as its result another State computation, and updates the state to have the values 1 and 2 pushed onto it.
An implementation of join for this State type would look like:
join :: State s (State s a) -> State s a
join outer = State $ \s ->
  let (inner, s') = runState outer s
  in runState inner s'
so it constructs a new stateful computation which first runs the outer computation to return a pair containing the inner computation and the new state. The inner computation is then run with the intermediate state.
If you plug your example into this definition then
outer = (State $ \s -> (push 10,1:2:s))
s = [0,0,0]
inner = push 10
s' = [1,2,0,0,0]
and the result is therefore the result of runState (push 10) [1,2,0,0,0] which is ((),[10,1,2,0,0,0])
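The walkthrough above can be run as is. This sketch assumes the same definition of `push` guessed at earlier, and names the function `joinState` (my choice) so it cannot be confused with `Control.Monad.join`. No Monad instance is needed because join is written directly against the newtype.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

-- Assumed definition of push, as guessed in the answer.
push :: a -> State [a] ()
push x = modify (x:)

-- Run the outer computation first, then the computation it returned.
joinState :: State s (State s a) -> State s a
joinState outer = State $ \s ->
  let (inner, s') = runState outer s
  in runState inner s'

main :: IO ()
main =
  print (runState (joinState (State $ \s -> (push 10, 1:2:s))) [0,0,0])
  -- ((),[10,1,2,0,0,0])
```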
You mentioned following the definitions for join and >>=, so, let's try that.
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0] = ?
The definitions are, again
instance Monad (State s) where
  -- return :: a -> State s a
  return x = State $ \s -> (x,s)
so for x :: a, State $ \s -> (x,s) :: State s a; (*) ---->
  (State h) >>= f = State $ \s -> let (a, newState) = h s
                                      (State g) = f a
                                  in g newState
join m = m >>= id
and runState :: State s a -> s -> (a, s), i.e. it should be (*) <----
runState (State g) s = g s. So, following the definitions we have
runState (join (State $ \s -> (push 10,1:2:s))) [0,0,0]
  = runState (State g) [0,0,0]
    where (State g) = join (State $ \s -> (push 10,1:2:s))
            = (State $ \s -> (push 10,1:2:s)) >>= id
            -- (State h ) >>= f
            = State $ \s -> let (a, newState) = h s
                                (State g) = id a
                                h s = (push 10,1:2:s)
                            in g newState
            = State $ \s -> let (a, newState) = (push 10,1:2:s)
                                (State g) = a
                            in g newState
            = State $ \s -> let (State g) = push 10
                            in g (1:2:s)
Now, push 10 :: State s a is supposed to match with State g where g :: s -> (a, s); most probably it's defined as push 10 = State $ \s -> ((),(10:) s); so we have
            = State $ \s -> let (State g) = State $ \s -> ((),(10:) s)
                            in g (1:2:s)
            = State $ \s -> let g s = ((),(10:) s)
                            in g (1:2:s)
            = State $ \s -> ((),(10:) (1:2:s))
  = runState (State $ \s -> ((),(10:) (1:2:s)) ) [0,0,0]
  = (\s -> ((),(10:) (1:2:s))) [0,0,0]
  = ((), 10:1:2:[0,0,0])
So you see that push 10 is first produced as a result-value (with (a, newState) = (push 10,1:2:s)); then it is treated as the computation-description of type State s a, so it is run last (not first, as you thought).
As Lee describes, join :: State s (State s a) -> State s a; the meaning of this type is, a computation of type State s (State s a) is one that produces State s a as its result-value, and that is push 10; we can run it only after we get hold of it.

Increment function for state monad in Haskell

I have the following defined State Monad, with which I am trying to implement an increment function:
data State a = State (Int -> (a, Int))
instance Monad State where
  return x = State $ \s -> (x, s)
  (State f) >>= k = State $ \s ->
    let
      (x, s') = f s
      State f' = k x
    in f' s'
get :: State Int
get = State $ \s -> (s, s)
put :: Int -> State ()
put s = State $ \_ -> ((), s)
I have done the following:
increment :: State ()
increment = do
  a <- get
  put (a+1)
And this appears to work.
Is this correct, and how can I verify that the state is indeed being incremented? Perhaps more generally, how do I use get and put?
You need some way to extract the inner function of the state. You can either do this by pattern matching on State a like you do in your bind definition or you can define State using record syntax data State a = State {runState :: Int -> (a, Int)}. Once you have runState you can easily test your increment function using runState increment 1. Your use of get and put seems to be just fine, not quite sure what you want to know there.
Also, you should add an Applicative instance for State, because Applicative is a superclass of Monad as of GHC 7.10.
Let's add a few utility functions to this code:
-- Given a `State` computation and a starting state, run the computation
-- and obtain the result value and final state.
runState :: State a -> Int -> (a, Int)
runState (State f) init = f init
-- Given a `State` computation and a starting state, run the computation
-- and obtain the result value.
evalState :: State a -> Int -> a
evalState st i = fst (runState st i)
-- Given a `State` computation and a starting state, run the computation
-- and obtain the final state.
execState :: State a -> Int -> Int
execState st i = snd (runState st i)
Now, using one of these functions, how would you write a function that tests whether increment does in fact increment the state by one?
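One possible answer to that exercise, as a sketch: the check below uses `execState` and compares the final state with the input plus one. The name `incrementsByOne` is hypothetical, and the Functor and Applicative instances are added because recent GHC versions require them before a Monad instance is accepted.

```haskell
data State a = State (Int -> (a, Int))

instance Functor State where
  fmap f (State g) = State $ \s -> let (a, s') = g s in (f a, s')

instance Applicative State where
  pure x = State $ \s -> (x, s)
  State ff <*> State fa = State $ \s ->
    let (f, s')  = ff s
        (a, s'') = fa s'
    in (f a, s'')

instance Monad State where
  State f >>= k = State $ \s ->
    let (x, s') = f s
        State f' = k x
    in f' s'

get :: State Int
get = State $ \s -> (s, s)

put :: Int -> State ()
put s = State $ \_ -> ((), s)

increment :: State ()
increment = do
  a <- get
  put (a + 1)

-- Run a computation and keep only the final state.
execState :: State a -> Int -> Int
execState (State f) i = snd (f i)

-- Hypothetical check: does 'increment' turn state n into n + 1?
incrementsByOne :: Int -> Bool
incrementsByOne n = execState increment n == n + 1

main :: IO ()
main = do
  print (all incrementsByOne [-3 .. 3])          -- True
  print (execState (increment >> increment) 40)  -- 42
```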

Understanding State Monad

Looking at Learn You a Haskell's definition of the State Monad:
instance Monad (State s) where
  return x = State $ \s -> (x,s)
  (State h) >>= f = State $ \s -> let (a, newState) = h s
                                      (State g) = f a
                                  in g newState
I don't understand the types of h s and g newState in the lower right-hand side.
Can you please explain their types and what's going on?
State s a is a naming of a function---the "state transformer function"
s -> (a, s)
In other words, it takes an input state s and modifies that state while also returning a result, a. This forms a really general framework of "pure state". If our state is an integer, we can write a function which updates that integer and returns the new value---this is like a unique number source.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')
Here, a and s end up being the same type.
Now this is all fine and good, except that we're in trouble if we'd like to get two fresh numbers. For that we must somehow run upd twice.
The final result is going to be another state transformer function, so we're looking for a "state transformer transformer". I'll call it compose:
compose :: (s -> (a, s))         -- the initial state transformer
        -> (a -> (s -> (b, s)))  -- a new state transformer, built using the "result"
                                 -- of the previous one
        -> (s -> (b, s))         -- the result state transformer
This is a little hairy looking, but honestly it's fairly easy to write this function. The types guide you to the answer:
compose f f' = \s -> let (a, s')  = f s
                         (b, s'') = f' a s'
                     in (b, s'')
You'll notice that the s-typed variables, [s, s', s''] "flow downward" indicating that state moves from the first computation through the second leading to the result.
We can use compose to build a function which gets two unique numbers using upd
twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))
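Putting the pieces so far into one file gives something you can run. `compose`, `upd` and `twoUnique` are as above; `threeUnique` is a hypothetical extension of mine, added only to show that the result of `compose` can be composed again.

```haskell
compose :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
compose f f' = \s ->
  let (a, s')  = f s
      (b, s'') = f' a s'
  in (b, s'')

-- Bump the counter and return the new value.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))

-- Hypothetical extension: a third fresh number, built by composing once more.
threeUnique :: Int -> ((Int, Int, Int), Int)
threeUnique =
  compose twoUnique (\(a, b) s -> let (c, s') = upd s in ((a, b, c), s'))

main :: IO ()
main = do
  print (twoUnique 0)     -- ((1,2),2)
  print (threeUnique 10)  -- ((11,12,13),13)
```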
These are the basics of State. The only difference is that we recognize there's a common pattern going on inside of the compose function and we extract it. That pattern looks like
(>>=) :: State s a -> (a -> State s b ) -> State s b
(>>=) :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
It's implemented the same way, too. We just need to "wrap" and "unwrap" the State bit---that's the purpose of State and runState
State :: (s -> (a, s)) -> State s a
runState :: State s a -> (s -> (a, s))
Now we can take compose and compare it to (>>=)
compose f f' = \s -> let (a, s')  = f s
                         (b, s'') = f' a s'
                     in (b, s'')
(>>=) (State f) f' = State $ \s -> let (a, s')  = f s
                                       (b, s'') = runState (f' a) s'
                                   in (b, s'')
The State Monad certainly is confusing the first time you see it. The first thing that's important to understand is its data declaration, which is
newtype State s a = State { runState :: s -> (a,s) }
so a State contains a function with the type s -> (a,s). We can think of this as a function acting on some sort of generator and returning a tuple of a value and a new generator. This is how random numbers work in Haskell, for example: s is the generator while a is the result of the function that takes a generator as input and outputs a random number a (say, of type Int, but it could just as easily be any other type).
Now let's talk about the instance declaration. Recall the type of (>>=) is
Monad m => m a -> (a -> m b) -> m b
In particular, we note that f should have the type a -> m b. In this case, m is State s, so the type of f should be a -> State s b. So now we can break down the instance declaration
(State h) >>= f = State $ \s -> let (a, newState) = h s
                                    (State g) = f a
                                in g newState
Since f has the type a -> State s b, the type of State g must be State s b (i.e. g :: s -> (b,s)), and since h has the type s -> (a,s), we must have newState :: s. Thus the result of the bind expression is g newState, which is of type (b, s).
For further reading, here is a great article that helped me to understand the State Monad when I first came across it.
From the definition of the State monad at LYAH:
newtype State s a = State { runState :: s -> (a,s) }
This means the argument to the State data constructor is a function which takes a state and produces an a and a new state. Thus h in the example above is a function, and h s computes a and newState.
From Hoogle we see the type of (>>=) is
(>>=) :: Monad m => m a -> (a -> m b) -> m b
which means f is also a function, from a to State s b. Thus it makes sense to give f the argument a, and the result is a State. Just like h, g is the argument to a State constructor: it takes a state (in this case newState) and returns a pair (b, newState2).
It might be more instructive to ask what (>>=) actually does: it lifts the function argument to a monad. A State is just a placeholder for a value depending on the current state, which is why the argument to the constructor depends on the state. Thus given a State "value", we first apply the state \s -> let (a, newState) = h s to get the corresponding value and a new state. Now we pass that value to the function (note that the types match up) and get a new state, i.e. a new function from a state to a value. Finally, we evaluate that state at newState to thread the state to the next part of the computation.

Resources