When does Keras reset an LSTM state?

I read all sorts of texts about it, and none seem to answer this very basic question. It's always ambiguous:
In a stateful = False LSTM layer, does keras reset states after:
Each sequence; or
Each batch?
Suppose I have X_train shaped as (1000,20,1), meaning 1000 sequences of 20 steps of a single value. If I make:
model.fit(X_train, y_train, batch_size=200, nb_epoch=15)
Will it reset states for every single sequence (resets states 1000 times)?
Or will it reset states for every batch (resets states 5 times)?

Checking with some tests, I got to the following conclusion, which agrees with the documentation and with Nassim's answer:
First, there isn't a single state in a layer, but one state per sample in the batch. There are batch_size parallel states in such a layer.
In a stateful=False case, all the states are reset together after each batch.
A batch with 10 sequences would create 10 states, and all 10 states are reset automatically after it's processed.
The next batch with 10 sequences will create 10 new states, which will also be reset after this batch is processed.
If all those sequences have length (timesteps) = 7, the practical result of these two batches is:
20 individual sequences, each with length 7
None of the sequences are related. But of course: the weights (not the states) will be unique for the layer, and will represent what the layer has learned from all the sequences.
A state is: Where am I now inside a sequence? Which time step is it? How is this particular sequence behaving since its beginning up to now?
A weight is: What do I know about the general behavior of all sequences I've seen so far?
In a stateful=True case, there is also the same number of parallel states, but they will simply not be reset at all.
A batch with 10 sequences will create 10 states that will remain as they are at the end of the batch.
The next batch with 10 sequences (it's required to be 10, since the first was 10) will reuse the same 10 states that were created before.
The practical result is: the 10 sequences in the second batch are just continuing the 10 sequences of the first batch, as if there had been no interruption at all.
If each sequence has length (timesteps) = 7, then the actual meaning is:
10 individual sequences, each with length 14
When you see that you have reached the total length of the sequences, you call model.reset_states(). This means you will not continue the previous sequences anymore and will start feeding new sequences.
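The difference between the two cases can be sketched with a toy recurrence in plain Python. This is not Keras code: step, run_batch and the weights 0.5 and 0.9 are made up for illustration, and the only assumption is that the state after a time step depends on the input and on the previous state.

```python
import math

def step(state, x, w=0.5, u=0.9):
    """One toy recurrent step: the new state depends on the input and the old state."""
    return math.tanh(w * x + u * state)

def run_batch(batch, initial_states):
    """Run every sequence in the batch from its own state; return the final states."""
    final_states = []
    for sequence, state in zip(batch, initial_states):
        for x in sequence:
            state = step(state, x)
        final_states.append(state)
    return final_states

# Two batches of 10 sequences, 7 time steps each (arbitrary made-up values).
batch_1 = [[0.1 * (i + t) for t in range(7)] for i in range(10)]
batch_2 = [[0.1 * (i - t) for t in range(7)] for i in range(10)]
zeros = [0.0] * 10

# stateful=False behaviour: every batch starts from fresh (zero) states.
stateless_1 = run_batch(batch_1, zeros)
stateless_2 = run_batch(batch_2, zeros)

# stateful=True behaviour: batch 2 starts from the final states of batch 1.
stateful_1 = run_batch(batch_1, zeros)
stateful_2 = run_batch(batch_2, stateful_1)

# Reference: the 10 sequences of length 14 obtained by joining the two batches.
joined = [a + b for a, b in zip(batch_1, batch_2)]
joined_final = run_batch(joined, zeros)
```

In this sketch the carried-over run ends in the same states as the 10 joined sequences of length 14, while the run that starts each batch from zero treats the data as 20 unrelated sequences of length 7.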

In Keras there are two modes for maintaining states:
The default mode (stateful = False), where the state is reset after each sample. This is because samples/sequences in a batch are always independent, and each sample/sequence has a separate state which is the representation of that sequence. Another implication is that all the signal required to process a sequence in a batch is contained in that batch, and state from the previous batch is not required.
The stateful mode, where the state is never reset. It is up to the user to reset the state before a new epoch; Keras itself won't reset it. In this mode the state is propagated from sample "i" of one batch to sample "i" of the next batch. Generally it is recommended to reset the state after each epoch, as the state may grow for too long and become unstable. However, in my experience with small datasets (20,000-40,000 samples), resetting or not resetting the state after an epoch does not make much of a difference to the end result. For bigger datasets it may make a difference.
A stateful model is useful if you have patterns that span hundreds of time steps. Otherwise the default mode is sufficient. In my experience, setting the batch size roughly equal to the size (time steps) of the patterns in the data also helps.
The stateful setup can be quite difficult to grasp at first. The key point is that if very long sequences need to be divided into small sub-sequences and processed, the sub-sequences are lined up across batches and not within a batch. Within a batch the sequences are always independent. If state needs to be carried over from one sub-sequence to another, the sub-sequences should be lined up across batches, i.e. sequence i of batch n+1 is a continuation of sequence i from batch n, and so on. In this case the final state of sequence i from batch n is passed over as the initial state of sequence i in batch n+1.
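The lining-up across batches can be sketched as follows. The helper make_stateful_batches and the example numbers are hypothetical and only show the data layout; they are not part of Keras.

```python
def make_stateful_batches(long_sequences, sub_len):
    """Split each long sequence into chunks of sub_len steps.

    Batch n holds chunk n of every long sequence, so sample i of batch n+1
    continues sample i of batch n.
    """
    total_len = len(long_sequences[0])
    if any(len(seq) != total_len for seq in long_sequences):
        raise ValueError("all long sequences must have the same length")
    if total_len % sub_len != 0:
        raise ValueError("sequence length must be a multiple of sub_len")
    batches = []
    for start in range(0, total_len, sub_len):
        batches.append([seq[start:start + sub_len] for seq in long_sequences])
    return batches

# Three made-up long sequences of 12 steps, cut into sub-sequences of 4 steps.
long_sequences = [list(range(100 * i, 100 * i + 12)) for i in range(3)]
batches = make_stateful_batches(long_sequences, sub_len=4)
```

Here batches[0][1] holds steps 100 to 103 and batches[1][1] holds steps 104 to 107 of the same long sequence. Batches built this way would have to be fed in order and without shuffling, otherwise the one-to-one mapping between samples of successive batches is lost.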

In the doc of the RNN code you can read this :
Note on using statefulness in RNNs :
You can set RNN layers to be 'stateful', which means that the states
computed for the samples in one batch will be reused as initial states
for the samples in the next batch. This assumes a one-to-one mapping
between samples in different successive batches.
I know that this doesn't directly answer your question, but to me it confirms what I was thinking: when an LSTM is not stateful, the state is reset after every sample. They don't work by batches; the idea of a batch is that every sample is independent of the others.
So you have 1000 resets of the state in your example.

Everyone seems to be making it too confusing. Keras LSTM resets state after every batch.
Here is a good blog: https://machinelearningmastery.com/understanding-stateful-lstm-recurrent-neural-networks-python-keras/
Read the sections "LSTM State Within A Batch" and "Stateful LSTM for a One-Char to One-Char Mapping" in this blog. They show why the state must be reset only after a batch.

Expanding on @Nassim_Ben's answer, it is true that each sequence is considered independent for each instance of the batch. However, you need to keep in mind that the RNN's hidden state and cell memory get passed along to the next cell for 20 steps. The hidden state and cell memory are typically set to zero for the very first of the 20 cells.
After the 20th cell, and after the hidden state (only, not cell memory) gets passed on to the layers above the RNN, the state gets reset. I'm going to assume that they mean cell memory and hidden state here.
So yes, it does get reset for all 1000 instances. However, considering that your batch_size=200, it gets reset 5 times, with each batch getting reset after it is done passing information through those 20 steps. Hopefully you got your head around this.
Here's a project I did where I had the same question. Pay special attention to cell 15 and its explanation in the blob after cell 11. I kept appending letters because the state was getting reset otherwise.


PyTorch Lightning training stalling at the beginning of fourth batch

I am having an odd problem in PyTorch Lightning, which I'm using for finetuning a language model on a GPU. The first three training batches run very quickly (<1 second), then the fourth goes on for hours without finishing, and eventually I cancel the job. This is true whether I use batches of size 2 or 16.
Using the callbacks on_train_batch_start and on_train_batch_end to print 'batch started' and 'batch ended', I know that the first three batches have all completed, and the fourth doesn't reach the on_train_batch_start callback. This leads me to believe that the problem is somewhere in the DataLoader, since on_train_batch_start appears to be the first hook in the training loop, according to PyTorch Lightning's pseudocode.
I placed some print statements in my custom collate_fn for the DataLoader, and they all printed as well, so it appears that the problem arises sometime after collating occurs.
Does anyone have any idea what the issue could be or how I can interrogate the code further?

Driver scheduling (public transportation): enforcing 30 min break after 4 h of driving time

We're struggling with some aspects of the following problem:
a public transportation bus timetable consists of shifts (~ track sections) each with fixed start and end times
bus drivers need to be assigned to each of those shifts
[constraint in question] legal regulations demand that each bus driver has a 30 min break after 4 hours of driving (i.e. after driving shifts)
put differently, a driver accrues driving time when driving shifts that must not exceed 4h unless the driver takes a 30 min break in which case the accrued time is "reset to zero"
In summary, we need to track the accrued driving time of each driver in order to suppress shift assignments to enforce the 30 min break.
The underlying problem seems to sit halfway between a job shop and an assignment problem:
Like job shop problems, it has shifts (or tasks, jobs) with many no-overlap and precedence constraints between them...
...BUT our shifts (~tasks/jobs) are not pre-assigned to drivers; in job shop problems, by contrast, the tasks need to be executed on specific machines (~drivers) and are therefore pre-assigned, so assigning them is not part of the problem
Like assignment problems, we need to assign shifts to as few drivers as possible...
...BUT we also need to handle the aforementioned no-overlap and precedence constraints, which are not taken into account in assignment problems
So my question is: how can the above constraint best be modelled as a constraint problem with or-tools?
Thanks in advance!
One general technique for specifying patterns in constraint programming is the regular constraint (in Gecode, Choco, MiniZinc, among others, unsure of the status for or-tools), where patterns of variables are specified using finite automata (DFAs and NFAs) or regular expressions.
In your case, assuming that you have a sequence of variables representing what a certain driver does at each time-point, it is fairly straightforward to specify an automaton that accepts any sequence of values that does not contain more than four consecutive hours of driving. A sketch of such an automaton:
Driving states Dn representing n time units driving (for some resolution of time units), up to n=4 hours.
Break states DnBm for a break of length m after n time units of driving, up to m=30 minutes.
Start state is D0.
Driving: When driving 1 unit of time, move from state Dn to D(n+1), and from a break shorter than 30 minutes from DnBm to D(n+1).
Break: When taking a break of 1 unit of time, move from DnBm to DnB(m+1), unless the 30-minute break time has been reached, in which case the transition goes back to D0.
Other actions handled mostly as self-loops, depending on desired semantics.
Of course, details will vary for your specific use-case.
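The automaton sketched above can be written out as a small transition function in Python. This is only an illustration under stated assumptions: a resolution of 15 minutes per time unit, only the two actions "drive" and "break" (other actions would be added as self-loops), and the names next_state and accepts are made up. It checks a single driver's action sequence; it is not a solver model.

```python
UNIT_MINUTES = 15                       # assumed time resolution
MAX_DRIVE = 4 * 60 // UNIT_MINUTES      # 16 units = 4 hours
BREAK_LEN = 30 // UNIT_MINUTES          # 2 units = 30 minutes

def next_state(state, action):
    """Return the successor of state (n, m), or None if the move is forbidden.

    n = accrued driving units, m = units of the break currently in progress.
    """
    n, m = state
    if action == "drive":
        if n == MAX_DRIVE:
            return None                 # would exceed 4 hours without a full break
        return (n + 1, 0)               # an unfinished break does not reset n
    if action == "break":
        if m + 1 >= BREAK_LEN:
            return (0, 0)               # full break taken: back to D0
        return (n, m + 1)
    raise ValueError(f"unknown action: {action!r}")

def accepts(actions):
    """True if the whole action sequence stays inside the automaton."""
    state = (0, 0)                      # start state D0
    for action in actions:
        state = next_state(state, action)
        if state is None:
            return False
    return True
```

For example, accepts(["drive"] * 16 + ["break"] * 2 + ["drive"] * 16) holds, while 17 consecutive "drive" units are rejected. Enumerating next_state over all states and actions would give the kind of transition table that a regular constraint takes as input.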

Estimating WCET of a task on Linux

I want to approximate the Worst Case Execution Time (WCET) for a set of tasks on Linux. Most professional tools are either expensive (thousands of dollars) or don't support my processor architecture.
Since I don't need a tight bound, my line of thought is that I:
disable frequency scaling
disable unnecessary background services and tasks
set the program affinity to run on a specified core
run the program 50,000 times with various inputs
profile it and store the total number of cycles each run took to complete
given the largest clock cycle count and knowing the core frequency, I can get an estimate
Is this a sound practical approach?
Secondly, to account for interference from other tasks, I will run the whole task set (40 tasks) in parallel, with each task randomly assigned a core, and do the same thing 50,000 times.
Once I get the estimate, a 10% safety margin will be added to account for unforeseeable interference and untested paths. This 10% margin has been suggested in the paper "Approximation of Worst Case Execution time in Preemptive Multitasking Systems" by Corti, Brega and Gross.
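The arithmetic of the last steps can be sketched in Python. The function name estimate_wcet_seconds, the cycle counts and the 2 GHz core frequency are made-up values for illustration, not measurements.

```python
def estimate_wcet_seconds(cycle_counts, core_hz, margin=0.10):
    """Largest observed cycle count, converted to seconds, plus a safety margin."""
    if not cycle_counts:
        raise ValueError("need at least one measurement")
    worst_cycles = max(cycle_counts)
    return worst_cycles / core_hz * (1.0 + margin)

# Made-up measurements for illustration only.
measured_cycles = [1_200_000, 1_350_000, 1_180_000, 1_500_000]
estimate = estimate_wcet_seconds(measured_cycles, core_hz=2_000_000_000)
```

With these numbers the worst run takes 1,500,000 cycles, i.e. 0.75 ms at 2 GHz, and the 10% margin raises the estimate to 0.825 ms.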
Some comments:
1) Even attempting to compute worst case bounds in this way means making assumptions that there aren't uncommon inputs that cause tasks to take much more or even much less time. An extreme example would be a bug that causes one of the tasks to go into an infinite loop, or that causes the whole thing to deadlock. You need something like a code review to establish that the time taken will always be pretty much the same, regardless of input.
2) It is possible that the input data does influence the time taken to some extent. Even if this isn't apparent to you, it could happen because of the details of the implementation of some library function that you call. So you need to run your tests on a representative selection of real life data.
3) When you have got your 50K test results, I would draw some sort of probability plot - see e.g. http://www.itl.nist.gov/div898/handbook/eda/section3/normprpl.htm and links off it. I would be looking for isolated points that show that in a few cases some runs were suspiciously slow or suspiciously fast, because the code review from (1) said there shouldn't be runs like this. I would also want to check that adding 10% to the maximum seen takes me a good distance away from the points I have plotted. You could also plot time taken against different parameters from the input data to check that there wasn't any pattern there.
4) If you want to try a very sophisticated approach, you could try fitting a statistical distribution to the values you have found - see e.g. https://en.wikipedia.org/wiki/Generalized_Pareto_distribution. But plotting the data and looking at it is probably the most important thing to do.
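The probability plot from comment 3 can be sketched with the standard library alone. The function name normal_probability_points is made up, and the plotting position (rank - 0.5) / n is a simple common choice assumed here; it is not necessarily the one used in the NIST handbook.

```python
from statistics import NormalDist

def normal_probability_points(samples):
    """Pair each sorted sample with the standard-normal quantile of its rank."""
    ordered = sorted(samples)
    n = len(ordered)
    std_normal = NormalDist()
    points = []
    for rank, value in enumerate(ordered, start=1):
        p = (rank - 0.5) / n            # simple plotting position, an assumption
        points.append((std_normal.inv_cdf(p), value))
    return points
```

Plotting the returned (quantile, value) pairs with any plotting tool gives the probability plot. Points at the upper end that sit far above the trend of the others would be the suspiciously slow runs, and the maximum plus 10% can be drawn as a horizontal line to see how far it lies from them.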

Bi-Threaded processing in Matlab

I have a Large-Scale Gradient Descent optimization problem that I am running using Matlab. The code has got two parts:
A Sequential update part that fires every iteration that updates the parameter vector.
A validation error computation part that fires every 10 iterations or so, using the parameter value at the end of the iteration in which it's fired.
The way that I am running this now is to do (1) and (2) sequentially. But (2) takes a lot of time and it's not the core part of my routine; I made it just to check the progress and plot the error of my model. Is it possible in Matlab to run (2) in parallel with (1)? Please note that (1) cannot be run in parallel since it performs a sequential update. So a simple 'parfor' usage is not a solution, unless there is a really smart way of doing that.
I don't think Matlab has any way of multi-threading outside of the (rather restricted) Parallel Computing Toolbox. There is a workaround which may help you though:
Open 2 sessions of Matlab, sessions A and B (or instances, or workspaces, however you call it)
Matlab session A:
Calculate the 10 iterations of your sequential process (1)
Save the result in a file (adequately and uniquely named)
Go on to calculate the next 10 iterations (back to the top of this loop basically)
In parallel:
Matlab session B:
Check periodically for the existence of the file written by process A (define a timer that will do that at a time interval which makes sense for your process, a few seconds or a few minutes...)
If the file exists => load it, then do the validation computation (your process (2)) and display/report the results.
Note: this only works if process (1) doesn't need the result of process (2) to run its iterations; if it does, I don't know how you could parallelise it anyway.
If you have multiple cores on your machine, this should run smoothly. If you have a single core, the 2 sessions will have to share it and you will see a performance impact.

Implementing Concurrent writes in the CRCW threading model

PRAM models for parallel computing come in three main flavours: EREW, CREW and CRCW.
I can understand how EREW and CREW can be implemented on a multicore machine. But how would one go about implementing the CRCW model on a multicore CPU? Is it even a practical model, since concurrent writes are not possible and every basic parallel programming course goes into great detail about race conditions?
Essentially this means that trying to avoid race conditions and trying to implement concurrent writes are two opposing goals.
First up: We know that the PRAM is a theoretical, or abstract machine. There are several simplifications made so that it may be used for analyzing/designing parallel algorithms.
Next, let's talk about the ways in which one may do 'concurrent writes' meaningfully.
Concurrent write memories are usually divided into subclasses, based on how they behave:
Priority based CW - Processors have a priority, and if multiple concurrent writes to the same location arrive, the write from the processor of highest priority gets committed to memory.
Arbitrary CW - One processor's write is arbitrarily chosen for commit.
Common CW - Multiple concurrent writes to the same location are committed only if the values being written are the same. i.e. all writing processors must agree on the value being written.
Reduction CW - A reduction operator is applied on the multiple values being written. e.g. a summation, where multiple concurrent writes to the same location lead to the sum of the values being written to be committed to memory.
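The four subclasses can be sketched as one Python function that decides what a single memory cell holds after several simultaneous writes. The function name, the policy strings and the rule that the lowest processor id has the highest priority are assumptions made for this illustration.

```python
from functools import reduce
import operator

def resolve_concurrent_writes(writes, policy, combine=operator.add):
    """Decide which value one memory cell holds after simultaneous writes.

    writes: list of (processor_id, value) pairs aimed at the same address.
    """
    if not writes:
        raise ValueError("no writes to resolve")
    values = [value for _, value in writes]
    if policy == "priority":
        # Assumption: the lowest processor id has the highest priority.
        return min(writes, key=lambda w: w[0])[1]
    if policy == "arbitrary":
        # Any single write may win; this sketch simply takes the first one.
        return values[0]
    if policy == "common":
        if any(v != values[0] for v in values):
            raise ValueError("common CW requires all written values to agree")
        return values[0]
    if policy == "reduction":
        # Combine all values with the reduction operator (summation by default).
        return reduce(combine, values)
    raise ValueError(f"unknown policy: {policy!r}")
```

For the writes [(3, 10), (1, 20), (2, 30)], the priority policy commits 20 (processor 1 wins), the reduction policy commits 60, and the common policy raises an error because the values differ.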
These subclasses lead to some interesting algorithms. Some of the examples I remember from class are:
A CRCW-PRAM where the concurrent write is achieved as a summation can sum an arbitrarily large number of integers in a single timestep. There is a processor for each integer in the input array. All processors write their value to the same location. Done.
Imagine a CRCW-PRAM where the memory commits concurrent writes only if the value written by all processors is the same. Now imagine N numbers A[1] ... A[N], whose maximum you need to find. Here's how you'd do it:
Step 1:
N^2 processors will compare each value to each other value, and write the result to a 2D array:
parallel_for i in [1,N]
    parallel_for j in [1,N]
        if (A[i] >= A[j])
            B[i,j] = 1
        else
            B[i,j] = 0
So in this 2D array, the row corresponding to the biggest number will be all 1's.
Step 2:
Find the row which has only 1's, and store the corresponding value as the max.
parallel_for i in [1,N]
    M[i] = 1
    parallel_for j in [1,N]
        if (B[i,j] == 0)
            M[i] = 0 // multiple concurrent writes of *same* value
    if M[i]
        max = A[i]
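The two steps can be checked with a sequential simulation in Python. Each loop iteration stands in for one processor, so this only verifies the logic and says nothing about the constant running time on a real CRCW machine; the function name crcw_max is made up.

```python
def crcw_max(a):
    """Sequentially simulate the two-step common-CRCW maximum algorithm."""
    n = len(a)
    if n == 0:
        raise ValueError("empty input")
    # Step 1: one (virtual) processor per pair (i, j).
    b = [[1 if a[i] >= a[j] else 0 for j in range(n)] for i in range(n)]
    # Step 2: m[i] stays 1 only if row i of b contains no 0.
    m = [1] * n
    for i in range(n):
        for j in range(n):
            if b[i][j] == 0:
                m[i] = 0      # every writer stores the same value, 0
    # Every i with m[i] == 1 holds the maximum, so all writers agree on the value.
    winners = {a[i] for i in range(n) if m[i]}
    assert len(winners) == 1
    return winners.pop()
```

If the maximum occurs more than once, several processors write to max at the same time, but they all write the same value, which is exactly what the common CW rule permits.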
Finally, is it possible to implement for real?
Yes, it is possible. Designing, say, a register file, or a memory and associated logic, which has multiple write ports, and which arbitrates concurrent writes to the same address in a meaningful way (like the ways I described above) is possible. You can probably already see that based on the subclasses I mentioned. Whether or not it is practical, I cannot say. I can say that in my limited experience with computers (which involves mostly using general purpose hardware, like the Core Duo machine I'm currently sitting before), I haven't seen one in practice.
EDIT: I did find a CRCW implementation. The Wikipedia article on PRAM describes a CRCW machine which can find the max of an array in 2 clock cycles (using the same algorithm as the one above). The description is in SystemVerilog and can be implemented in an FPGA.
