How to improve this toy Jax optimizer code with while loops and saved history? - jax

I'm writing a custom optimizer I want JIT-able with Jax which features 1) breaking on maximum steps reached 2) breaking on a tolerance reached, and 3) saving the history of the steps taken. I'm relatively new to some of this stuff in Jax, but reading the docs I have this solution:
import jax, jax.numpy as jnp

@jax.jit
def optimizer(x, tol = 1, max_steps = 5):

    def cond(arg):
        step, x, history = arg
        return (step < max_steps) & (x > tol)

    def body(arg):
        step, x, history = arg
        x = x / 2 # simulate taking an optimizer step
        history = history.at[step].set(x) # simulate saving current step
        return (step + 1, x, history)

    return jax.lax.while_loop(
        cond,
        body,
        (0, x, jnp.full(max_steps, jnp.nan))
    )

optimizer(10.) # works
My question is whether this can be improved in some way. In particular, is there a way to avoid pre-allocating the history? This isn't ideal since the real thing is a lot more complicated than a single array, and there's obviously the potential for wasted memory if the tolerance is reached well before the maximum number of steps.

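is there a way to avoid pre-allocating the history?
No, as far as I understand JAX.
In JAX, an array's "type" includes its shape, so the value passed into the body function of jax.lax.while_loop and the value it returns MUST have the same shape and dtype. If you try to grow the history dynamically, say with jnp.vstack((history, x)), the shape changes between iterations and JAX raises an error because the loop carry types no longer match.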
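Given the fixed-shape requirement, one common pattern is to keep the pre-allocated buffer inside the loop and trim it afterwards using the returned step count. The sketch below is only an illustration of that idea: the use of static_argnames for max_steps and the trimming step outside of jit are assumptions added here, not part of the question's code.

```python
from functools import partial

import jax
import jax.numpy as jnp


# max_steps determines the buffer shape, so it is marked static (assumption).
@partial(jax.jit, static_argnames=("max_steps",))
def optimizer(x, tol=1.0, max_steps=5):
    def cond(carry):
        step, x, _ = carry
        return (step < max_steps) & (x > tol)

    def body(carry):
        step, x, history = carry
        x = x / 2  # simulate taking an optimizer step
        history = history.at[step].set(x)  # write into the fixed-size buffer
        return step + 1, x, history

    init = (0, x, jnp.full(max_steps, jnp.nan))
    return jax.lax.while_loop(cond, body, init)


steps, x_final, history = optimizer(10.0, max_steps=8)

# Outside of jit the step count is a concrete value, so slicing is allowed.
trimmed = history[: int(steps)]
print(int(steps), trimmed)
```

The buffer is still allocated at its maximum size while the loop runs; the trimming only removes the unused NaN tail from what is returned to the rest of the program.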

There is a way, if you think the tolerance will be often reached before the maximum number of steps.
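JAX implements sparse matrices (and pytrees of them) in jax.experimental.sparse. They have the same logical shape as the maximum history size, and therefore satisfy the "fixed size" requirement for XLA, while only storing their specified elements in memory. Be aware that under jit the number of stored elements is itself a static buffer size, so this mainly helps when the full history is much larger than the number of entries you expect to keep.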

Related

Is vmap efficient as compared to batched ops?

I am playing around with some Jax and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch." Is this the correct way? Other tools I have worked with in the past (PyTorch, TF) typically have a "batch" dimension kind of implicit. I kind of assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
Is vmap the correct/expected way to batch train models in JAX (at least most of the time)?
Is it not the case that per-operation batching would be somehow faster and handled more naturally by some CUDA function someplace (in the case of using CUDA)? Does JAX realize that, say, it's not vmapping over my model parameter dimensions, and use the correct batched matmuls and other ops? Or is it that the ops don't actually work like this, and vmapping (naively batching over the entire sequence of calculations) is actually what's happening even in something like PyTorch?
This is a theoretical question. My code currently works, but I am just curious as to the "why" of my approach.
vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available, JAX just doesn't expose it until it's needed.
If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to re-define your model so that it accepts batches. This takes some thought, in particular we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace dot with appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
But it's not great to have to rewrite the model every time we want to batch an operation. This is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!) and it produces the same result given the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
You might ask now which one of these approaches is more efficient: it turns out that under the hood, both the manual and automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations, but you'll see that the actual sequence of operations is identical between the two approaches: it's dot_general, then add, then tanh, then reduce_sum.
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
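On the question of whether vmap also maps over the model parameters: that is controlled by the in_axes argument. The following is a minimal sketch, assuming a variant of the model above that takes M and b as explicit arguments (the concrete values are made up for illustration). Passing None for an argument tells vmap to broadcast it rather than batch over it.

```python
import jax
import jax.numpy as jnp


def model(v, M, b):
    # Single-instance model: v has shape (3,), M has shape (2, 3).
    return jnp.tanh(M @ v + b).sum()


# Illustrative values only.
M = jnp.arange(6.0).reshape(2, 3) / 10.0
b = 1.0
v_batched = jnp.arange(15.0).reshape(5, 3) / 10.0

# Batch over axis 0 of v only; M and b are shared across the batch.
batched_model = jax.vmap(model, in_axes=(0, None, None))
out_vmap = batched_model(v_batched, M, b)

# Hand-written batched version for comparison.
out_manual = jnp.tanh(v_batched @ M.T + b).sum(axis=1)

print(out_vmap)
print(out_manual)
```

Both versions should agree up to floating-point rounding, which matches the point made above that vmap rewrites the program into ordinary batched operations.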

How to get a 2D output from linear layer in pytorch?

I would like to project a tensor into a space with an additional dimension.
I tried
torch.nn.Linear(
in_features=num_inputs,
out_features=(num_inputs, num_additional),
)
But this results in an error
A workaround would be to
torch.nn.Linear(
in_features=num_inputs,
out_features=num_inputs*num_additional,
)
and then change the view of the output
output.view(batch_size, num_inputs, num_additional)
But I imagine this workaround will get tricky to read, especially when a projection into more than one additional dimension is desired.
Is there a more direct way to code this operation?
Perhaps the source code for linear can be changed
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
To accept more dimensions for the weight and bias initialization, and F.linear seems like it would need to be replaced with a different function.
IMO the workaround you provided is already clear enough. However, if you want to express this as a single operation, you can always write your own module by subclassing torch.nn.Linear:
import numpy as np
import torch

class MultiDimLinear(torch.nn.Linear):
    def __init__(self, in_features, out_shape, **kwargs):
        self.out_shape = out_shape
        # Flatten the requested output shape into a single feature count.
        out_features = int(np.prod(out_shape))
        super().__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        # Restore the multi-dimensional output shape after the projection.
        return out.reshape((len(x), *self.out_shape))

if __name__ == '__main__':
    tmp = torch.empty((32, 10))
    linear = MultiDimLinear(in_features=10, out_shape=(10, 10))
    out = linear(tmp)
    print(out.shape) # torch.Size([32, 10, 10])
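If subclassing feels heavy, a similar effect can probably be had by composing a plain Linear layer with torch.nn.Unflatten. This is a sketch of an alternative, not what the answer above proposes; the sizes are arbitrary illustration values.

```python
import math

import torch

in_features = 10
out_shape = (10, 3)  # hypothetical target shape per sample

layer = torch.nn.Sequential(
    # Project to a flat vector with prod(out_shape) entries.
    torch.nn.Linear(in_features, math.prod(out_shape)),
    # Split dimension 1 back into the desired multi-dimensional shape.
    torch.nn.Unflatten(dim=1, unflattened_size=out_shape),
)

x = torch.randn(32, in_features)
out = layer(x)
print(out.shape)
```

This keeps the reshape declared next to the layer that needs it, and it extends to more than one additional dimension by lengthening out_shape.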
Another way would be to use torch.einsum
https://pytorch.org/docs/stable/generated/torch.einsum.html
torch.einsum can prevent summation across dimensions in tensor to tensor multiplication operations. This can allow separate multiplication operations to happen in parallel. [ I do not know if this would necessarily result in GPU efficiency; if the operations are still occurring in the same kernel. In fact, it may be slower https://github.com/pytorch/pytorch/issues/32591 ]
How this would work is to directly initialize the weight and bias tensors (look at source code for the torch linear layer for that code)
Say that the input (X) has dimensions (a, b), where a is the batch size.
Say that you want to pass this input through a series of classifiers, represented in a single weight tensor (W) with dimensions (c, d, e), where c is the number of classifiers, and e is the number of classes for the classifier
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*2).view(5, 4, 2)
torch.einsum('ab, cbe -> ace', x, w)
In the last line, a and b are the dimensions of the input as mentioned above. What might be the tricky part is that c, b, and e are the dimensions of the classifiers' weight tensor; I didn't use d, I used b instead. That is because the vector multiplication happens along that dimension for both the input tensor and the weight tensor. So that's why the left side of the einsum equation is ab, cbe. The right side of the einsum equation simply lists which dimensions to exclude from summation.
The final dimensions we want are (a, c, e). a is the batch size, c is the number of classifiers, and e is the number of classes for each classifier. We do not want to sum over those dimensions, so to preserve their separation, the right side of the equation is ace.
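To make the index bookkeeping concrete, here is a small sketch that compares the einsum against an explicit loop over the classifiers. The float dtype and the loop-based reference are additions for illustration only.

```python
import torch

x = torch.arange(2 * 4, dtype=torch.float32).view(2, 4)         # (a, b)
w = torch.arange(5 * 4 * 2, dtype=torch.float32).view(5, 4, 2)  # (c, b, e)

# Contract over b; keep a (batch), c (classifier) and e (class).
out = torch.einsum('ab,cbe->ace', x, w)

# Reference: apply each classifier separately and stack along dim 1.
reference = torch.stack([x @ w[c] for c in range(w.shape[0])], dim=1)

print(out.shape)
print(torch.allclose(out, reference))
```

The loop version applies one (b, e) weight matrix at a time, which is what the einsum expresses in a single call.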
For those unfamiliar with einsum, this will be harder to read than the workaround I created (though I highly recommend learning it, because it gets very easy and intuitive very fast even though it's a bit tricky at first https://www.youtube.com/watch?v=pkVwUVEHmfI )
However, for parallelizing certain operations (especially on GPU), it seems that einsum is the only way to do it. For example, say that in my previous example I didn't want to use a classification head yet; I just wanted to project to multiple dimensions.
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*4).view(5, 4, 4)
y = torch.einsum('ab, cbe -> ace', x, w)
And say I do a few other operations to y, perhaps some non linear operations, activations, etc.
z = f(y)
z will still have the dimensions 2, 5, 4. Batch size two, 5 hidden states per batch, and the dimension of those hidden states are 4.
And then I want to apply a classifier to each separate tensor.
w2 = torch.arange(4*2).view(4, 2)
final = torch.einsum('fgh, hj -> fgj', z, w2)
Quick refresher: 2 is the batch size, 5 is the number of classifiers, and 2 is the number of outputs for each classifier.
The output dimensions, f, g, j (2, 5, 2) will not be summed across, and thus will be preserved in the output.
As cited in the github link, this may be slower than just using regular linear layers. There may be efficiencies in a very large number of parallel operations.

Unexpected solution using JiTCDDE

I'm trying to investigate the behavior of the following Delayed Differential Equation using Python:
y''(t) = -y(t)/τ^2 - 2y'(t)/τ - Nd*f(y(t-T))/τ^2,
where f is a cut-off function which is essentially equal to the identity when the absolute value of its argument is between 1 and 10 and otherwise is equal to 0 (see figure 1), and Nd, τ and T are constants.
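For numerical integration, a second-order equation like this is usually rewritten as a first-order system. A sketch of that reduction, using the names $y_0 = y$ and $y_1 = y'$ (chosen here only to mirror the y(0) and y(1) indexing that JiTCDDE uses):

```latex
\begin{aligned}
y_0'(t) &= y_1(t), \\
y_1'(t) &= -\frac{y_0(t)}{\tau^2} - \frac{2\,y_1(t)}{\tau} - \frac{N_d\, f\bigl(y_0(t-T)\bigr)}{\tau^2}.
\end{aligned}
```

Setting both derivatives to zero gives $y_1 = 0$ and $y_0 = -N_d\, f(y_0)$, which is consistent with the statement below that, without noise, zero is the only constant solution when $f$ behaves like the identity or vanishes.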
For this I'm using the package JiTCDDE. This provides a reasonable solution to the above equation. Nevertheless, when I try to add noise to the right-hand side of the equation, I obtain a solution which stabilizes to a non-zero constant after a few oscillations. This is not a mathematical solution of the equation (the only possible constant solution being equal to zero). I don't understand why this problem arises and whether it is possible to solve it.
I reproduce my code below. Here, for the sake of simplicity, I substituted the noise with a high-frequency cosine, which is introduced into the system of equations as the initial condition for a dummy variable (the cosine could have been introduced directly into the system, but for general noise this doesn't seem possible). To simplify the problem further, I also removed the term involving the function f, as the problem arises without it as well. Figure 2 shows the plot of the function given by the code.
from jitcdde import jitcdde, y, t
import numpy as np
from matplotlib import pyplot as plt
import math
import symengine
from chspy import CubicHermiteSpline

# Definition of function f
# (unused below; Bmin and Bmax are not defined in this excerpt):
def functionf(x):
    return x/4*(1+symengine.erf(x**2-Bmin**2))*(1-symengine.erf(x**2-Bmax**2))

# Parameters:
τ = 42.9
T = 35.33
Nd = 8.32

# Definition of the initial conditions:
dt = .01 # Time step.
totT = 10000. # Total time.
Nmax = int(totT / dt) # Number of time steps.
Vt = np.linspace(0., totT, Nmax) # Vector of times.

# Definition of the "noise"
X = np.zeros(Nmax)
for i in range(Nmax):
    X[i]=math.cos(Vt[i])

# Store the noise as the past of a third (dummy) component.
past=CubicHermiteSpline(n=3)
for time, datum in zip(Vt,X):
    regular_past = [10.,0.]
    past.append((
        time-totT,
        np.hstack((regular_past,datum)),
        np.zeros(3)
    ))
noise= lambda t: y(2,t-totT)

# Integration of the DDE
g = [
    y(1),
    -y(0)/τ**2-2*y(1)/τ+0.008*noise(t)
]
g.append(0) # The dummy component itself does not evolve.
DDE = jitcdde(g)
DDE.add_past_points(past)
DDE.adjust_diff()
data = []
for time in np.arange(DDE.t, DDE.t+totT, 1):
    data.append( DDE.integrate(time)[0] )
plt.plot(data)
plt.show()
Incidentally, I noticed that even without noise, the solution seems to be discontinuous at the point zero (y is set to be equal to zero for negative times), and I don't understand why.
As the comments unveiled, your problem eventually boiled down to this:
step_on_discontinuities assumes delays that are small with respect to the integration time and performs steps that are placed at those times where the delayed components point to the integration start (0 in your case). This way initial discontinuities are handled.
However, implementing an input with a delayed dummy variable introduces a large delay into the system, totT in your case.
The respective step for step_on_discontinuities would be at totT itself, i.e., after the desired integration time.
Thus when you reach for time in np.arange(DDE.t, DDE.t+totT, 1): in your code, DDE.t is totT.
Therefore you have made a big step before you actually start integrating and observing, which may look like a discontinuity and lead to weird results. In particular, you do not see the effect of your input, because it has already “ended” at this point.
To avoid this, use adjust_diff or integrate_blindly instead of step_on_discontinuities.

How to correctly implement a batch-input LSTM network in PyTorch?

This release of PyTorch seems to provide PackedSequence for variable-length inputs to recurrent neural networks. However, I found it a bit hard to use correctly.
Using pad_packed_sequence to recover the output of an RNN layer that was fed by pack_padded_sequence, we get a T x B x N tensor outputs, where T is the max time steps, B is the batch size and N is the hidden size. I found that for short sequences in the batch, the subsequent output will be all zeros.
Here are my questions.
For a single-output task where one needs the last output of all the sequences, a simple outputs[-1] will give a wrong result, since this tensor contains lots of zeros for short sequences. One will need to construct indices from the sequence lengths to fetch the individual last output for each sequence. Is there a simpler way to do that?
For a multiple-output task (e.g. seq2seq), usually one will add a linear layer N x O, reshape the batch outputs T x B x O into TB x O, and compute the cross entropy loss with the true targets TB (usually integers in a language model). In this situation, do these zeros in the batch output matter?
Question 1 - Last Timestep
This is the code that I use to get the output of the last timestep. I don't know if there is a simpler solution. If there is, I'd like to know it. I followed this discussion and grabbed the relevant code snippet for my last_timestep method. This is my forward.
class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
        ...

    def last_timestep(self, unpacked, lengths):
        # Index of the last output for each sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        return unpacked.gather(1, idx).squeeze()

    def forward(self, x, lengths):
        embs = self.embedding(x)

        # pack the batch
        packed = pack_padded_sequence(embs, list(lengths.data),
                                      batch_first=True)

        out_packed, (h, c) = self.rnn(packed)

        out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

        # get the outputs from the last *non-masked* timestep for each sentence
        last_outputs = self.last_timestep(out_unpacked, lengths)

        # project to the classes using a linear layer
        logits = self.linear(last_outputs)

        return logits
Question 2 - Masked Cross Entropy Loss
Yes, by default the zero-padded timesteps (targets) matter. However, it is very easy to mask them. You have two options, depending on the version of PyTorch that you use.
PyTorch 0.2.0: PyTorch now supports masking directly in CrossEntropyLoss, with the ignore_index argument. For example, in language modeling or seq2seq, where I add zero padding, I mask the zero-padded words (targets) simply like this:
loss_function = nn.CrossEntropyLoss(ignore_index=0)
PyTorch 0.1.12 and older: In the older versions of PyTorch, masking was not supported, so you had to implement your own workaround. A solution that I used was masked_cross_entropy.py, by jihunchoi. You may also be interested in this discussion.
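To illustrate what ignore_index does, the sketch below compares it against manually dropping the padded positions before computing the loss. The logits and targets are made-up values, and index 0 is assumed to be the padding label as in the example above.

```python
import torch
from torch import nn

# Four flattened timesteps, three classes; class 0 is reserved for padding.
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 3.0],
                       [1.0, 1.0, 1.0]])
targets = torch.tensor([1, 2, 0, 0])  # last two positions are padding

# Loss that skips every position whose target equals 0.
masked_loss = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)

# Reference: filter out the padded positions by hand.
keep = targets != 0
reference_loss = nn.CrossEntropyLoss()(logits[keep], targets[keep])

print(masked_loss.item(), reference_loss.item())
```

With the default mean reduction, the average is taken over the non-ignored positions only, so the padded timesteps contribute neither to the loss value nor to the gradient.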
A few days ago, I found this method which uses indexing to accomplish the same task with a one-liner.
I have my dataset batch first ([batch size, sequence length, features]), so for me:
unpacked_out = unpacked_out[np.arange(unpacked_out.shape[0]), lengths - 1, :]
where unpacked_out is the output of torch.nn.utils.rnn.pad_packed_sequence.
I have compared it with the method described here, which looks similar to the last_timestep() method Christos Baziotis is using above (also recommended here), and the results are the same in my case.
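For a unidirectional LSTM fed with a packed sequence, the final hidden state returned by the layer should already hold the last valid output of each sequence, which may be the simplest route of all. The sketch below checks this against the indexing approach; all sizes and the random input are illustrative assumptions.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch, max_len, feat, hidden = 3, 5, 4, 6
x = torch.randn(batch, max_len, feat)
lengths = torch.tensor([5, 3, 2])  # sorted in decreasing order

rnn = nn.LSTM(feat, hidden, batch_first=True)
packed = pack_padded_sequence(x, lengths, batch_first=True)
out_packed, (h, c) = rnn(packed)
out, _ = pad_packed_sequence(out_packed, batch_first=True)

# Last valid output per sequence via advanced indexing.
last_by_index = out[torch.arange(batch), lengths - 1]

# Final hidden state of the last layer, one row per sequence.
last_from_state = h[-1]

# Positions beyond a sequence's length are zero-filled in the padded output.
padding_is_zero = bool((out[1, 3:] == 0).all())

print(torch.allclose(last_by_index, last_from_state, atol=1e-6))
```

This shortcut does not carry over unchanged to bidirectional layers, where the backward direction's final state corresponds to the first timestep rather than the last.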

Purpose of 'givens' variables in Theano.function

I was reading the code for the logistic function given at http://deeplearning.net/tutorial/logreg.html. I am confused about the difference between inputs & givens variables for a function. The functions that compute mistakes made by a model on a minibatch are:
test_model = theano.function(
    inputs=[index],
    outputs=classifier.errors(y),
    givens={
        x: test_set_x[index * batch_size: (index + 1) * batch_size],
        y: test_set_y[index * batch_size: (index + 1) * batch_size]})
validate_model = theano.function(
    inputs=[index],
    outputs=classifier.errors(y),
    givens={
        x: valid_set_x[index * batch_size:(index + 1) * batch_size],
        y: valid_set_y[index * batch_size:(index + 1) * batch_size]})
Why couldn't/wouldn't one just make x & y shared input variables and let them be defined when an actual model instance is created?
The givens parameter allows you to separate the description of the model from the exact definition of the input variables. This is a consequence of what the givens parameter does: it modifies the graph before compiling it. In other words, we substitute each key in givens with its associated value in the graph.
In the deep learning tutorial, we use a normal Theano variable to build the model. We use givens to speed things up on the GPU. If we kept the dataset on the CPU, we would transfer a mini-batch to the GPU at each function call. As we do many iterations over the dataset, we would end up transferring the dataset to the GPU multiple times. As the dataset is small enough to fit on the GPU, we put it in a shared variable so that it is transferred to the GPU if one is available (or stays on the CPU if the GPU is disabled). Then, when compiling the function, we swap the input with a slice corresponding to the mini-batch of the dataset to use. The input of the Theano function is then just the index of the mini-batch we want to use.
I don't think anything is stopping you from doing it that way (I didn't try the updates= dictionary using an input variable directly, but why not). Remark however that for pushing data to a GPU in a useful manner, you will need it to be in a shared variable (from which x and y are taken in this example).
