Is vmap efficient as compared to batched ops? - jax

I am playing around with some Jax and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch" it. Is this the correct way? Other tools I have worked with in the past (PyTorch, TF) typically keep a "batch" dimension implicit. I kind of assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
Is vmap the correct/expected way to batch-train models in JAX (at least most of the time)?
Isn't per-operation batching somehow faster, handled more naturally by some CUDA function someplace (in the case of using CUDA)? Does JAX realize that, say, it is not vmapping over my model-parameter dimensions, and use the correct batched matmuls and other ops? Or is it that the ops don't actually work like this, and vmapping (naively batching over the entire sequence of calculations) is actually what's happening even in something like PyTorch?
This is a theoretical question. My code currently works; I am just curious about the "why" of my approach.

vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available, JAX just doesn't expose it until it's needed.
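To make that concrete, here is a quick sketch of my own (not part of the original answer): the model is written for a single example, and vmap's in_axes maps over the batch axis of the data only, leaving the parameters un-batched; the traced program contains the usual single batched matmul (dot_general), not a Python-level loop.
import jax
import jax.numpy as jnp

def predict(params, x):            # written for a single length-3 example x
    W, b = params
    return jnp.tanh(W @ x + b)

W = jnp.ones((2, 3))
b = jnp.zeros(2)
xs = jnp.ones((5, 3))              # a batch of 5 examples

# Batch over x (axis 0) but not over the parameters.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict((W, b), xs).shape)             # (5, 2)
print(jax.make_jaxpr(batched_predict)((W, b), xs))   # a single dot_general, add, tanh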

If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to re-define your model so that it accepts batches. This takes some thought; in particular, we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace dot with appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
But it's not great to have to re-write the model every time we want to batch an operation... this is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!) and produces the same result given the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
You might now ask which of these approaches is more efficient: it turns out that under the hood, both the manual and the automatic vectorized approach lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially just a way of naming an operation or set of operations; the actual sequence of operations is identical between the two approaches: dot_general, then add, then tanh, then reduce_sum.
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.

Related

How to improve this toy Jax optimizer code with while loops and saved history?

I'm writing a custom optimizer that I want to be JIT-able with JAX, which features 1) breaking when the maximum number of steps is reached, 2) breaking when a tolerance is reached, and 3) saving the history of the steps taken. I'm relatively new to some of this stuff in JAX, but reading the docs I have this solution:
import jax, jax.numpy as jnp
@jax.jit
def optimizer(x, tol=1, max_steps=5):
    def cond(arg):
        step, x, history = arg
        return (step < max_steps) & (x > tol)
    def body(arg):
        step, x, history = arg
        x = x / 2                          # simulate taking an optimizer step
        history = history.at[step].set(x)  # simulate saving current step
        return (step + 1, x, history)
    return jax.lax.while_loop(
        cond,
        body,
        (0, x, jnp.full(max_steps, jnp.nan))
    )
optimizer(10.) # works
My question is whether this can be improved in some way? In particular, is there a way to avoid pre-allocating the history? This isn't ideal since the real thing is a lot more complicated than a single array, and there's obviously the potential for wasted memory if the tolerance is reached well before the maximum number of steps.
is there a way to avoid pre-allocating the history?
No, not as far as I understand JAX.
In JAX, a value's 'type' includes its shape; that is, the input and output shapes of the body function MUST be the same. If you tried to grow the history dynamically, say with jnp.vstack((history, x)), the carry's shape would change between iterations and JAX would reject it.
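Here is a minimal sketch (my own illustration, not part of the original answer) of what goes wrong if the history grows: the carry of jax.lax.while_loop must keep the same shape on every iteration, so a shape-changing carry is rejected at trace time.
import jax
import jax.numpy as jnp

def cond(carry):
    step, _ = carry
    return step < 3

def body(carry):
    step, history = carry
    # appending changes the history's shape, and hence its JAX "type"
    history = jnp.concatenate([history, jnp.zeros(1)])
    return step + 1, history

try:
    jax.lax.while_loop(cond, body, (0, jnp.zeros(0)))
except TypeError as e:
    print("while_loop rejects the growing carry:", e)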
There is a way, if you expect the tolerance to often be reached well before the maximum number of steps.
JAX implements sparse matrices (and pytrees of them) in jax.experimental.sparse. They will have the same shape as the maximum history size, and therefore satisfy the "fixed size" requirement for XLA, but of course will only store nonzero elements in memory.
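As a rough illustration of that idea (my own sketch, not from the answer), a BCOO array keeps the full logical shape that XLA requires while materialising at most nse entries:
import jax.numpy as jnp
from jax.experimental import sparse

dense_history = jnp.zeros(5).at[0].set(3.0)                   # toy history buffer, one entry filled
sparse_history = sparse.BCOO.fromdense(dense_history, nse=2)  # store at most 2 entries

print(sparse_history.shape)                         # (5,) -- same logical shape as the dense buffer
print(sparse_history.data, sparse_history.indices)  # only nse entries are stored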

Is there an element-map function in PyTorch?

I'm new to PyTorch and I come from functional programming languages (where a map function is used everywhere). The problem is that I have a tensor and I want to do some operation on each element of the tensor. The operation may vary, so I need a function like this:
map : (Numeric -> Numeric) -> Tensor -> Tensor
e.g. map(lambda x: x if x < 255 else -1, tensor) # the example is simple, but the lambda may be very complex
Is there such a function in PyTorch? How should I implement such a function?
Most mathematical operations that are implemented for tensors (and similarly for ndarrays in numpy) are actually applied element-wise, so you could write, for instance:
mask = tensor < 255
result = tensor * mask + (-1) * ~mask
This is quite a general approach. For the case you have right now, where you only want to modify certain elements, you can also use "logical indexing", which lets you overwrite the current tensor in place:
tensor[~mask] = -1  # i.e. set the elements where tensor >= 255 to -1
So in Python there actually is a map() function, but usually there are better ways to do it (better in Python; in other languages, like Haskell, map/fmap is obviously preferred in most contexts).
So the key take-away here is that the preferred method is to take advantage of vectorization. This also makes the code more efficient, as those tensor operations are implemented in a low-level language, while map() is essentially a Python for-loop, which is a lot slower.
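For completeness, here is a small runnable sketch of my own (the values are arbitrary) comparing the masked arithmetic, the logical indexing, and a Python-level map:
import torch

tensor = torch.tensor([100., 200., 300., 400.])

# 1. Arithmetic with a boolean mask
mask = tensor < 255
result = tensor * mask + (-1) * ~mask

# 2. Logical indexing (in place, on a copy here to keep the original)
clipped = tensor.clone()
clipped[~mask] = -1

# 3. Python-level map: runs element by element in Python, much slower on large tensors
mapped = torch.tensor(list(map(lambda x: x if x < 255 else -1., tensor.tolist())))

print(result)   # tensor([100., 200.,  -1.,  -1.])
print(clipped)  # tensor([100., 200.,  -1.,  -1.])
print(mapped)   # tensor([100., 200.,  -1.,  -1.])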

Unexpected solution using JiTCDDE

I'm trying to investigate the behavior of the following Delayed Differential Equation using Python:
y''(t) = -y(t)/τ^2 - 2y'(t)/τ - Nd*f(y(t-T))/τ^2,
where f is a cut-off function which is essentially equal to the identity when the absolute value of its argument is between 1 and 10 and otherwise is equal to 0 (see figure 1), and Nd, τ and T are constants.
For this I'm using the package JiTCDDE. This provides a reasonable solution to the above equation. Nevertheless, when I try to add noise to the right-hand side of the equation, I obtain a solution which stabilizes to a non-zero constant after a few oscillations. This is not a mathematical solution of the equation (the only possible constant solution being equal to zero). I don't understand why this problem arises or whether it is possible to solve it.
I reproduce my code below. Here, for the sake of simplicity, I substituted the noise with a high-frequency cosine, which is introduced into the system of equations as the initial condition of a dummy variable (the cosine could have been introduced directly into the system, but for a general noise this doesn't seem possible). To simplify the problem further, I also removed the term involving the f function, as the problem arises without it as well. Figure 2 shows the plot of the function given by the code.
from jitcdde import jitcdde, y, t
import numpy as np
from matplotlib import pyplot as plt
import math
import symengine  # needed by functionf below
from chspy import CubicHermiteSpline

# Definition of function f (unused below; Bmin and Bmax would be the cut-off bounds):
def functionf(x):
    return x/4*(1+symengine.erf(x**2-Bmin**2))*(1-symengine.erf(x**2-Bmax**2))

# parameters:
τ = 42.9
T = 35.33
Nd = 8.32

# Definition of the initial conditions:
dt = .01               # Time step.
totT = 10000.          # Total time.
Nmax = int(totT / dt)  # Number of time steps.
Vt = np.linspace(0., totT, Nmax)  # Vector of times.

# Definition of the "noise"
X = np.zeros(Nmax)
for i in range(Nmax):
    X[i] = math.cos(Vt[i])

past = CubicHermiteSpline(n=3)
for time, datum in zip(Vt, X):
    regular_past = [10., 0.]
    past.append((
        time-totT,
        np.hstack((regular_past, datum)),
        np.zeros(3)
    ))
noise = lambda t: y(2, t-totT)

# Integration of the DDE
g = [
    y(1),
    -y(0)/τ**2 - 2*y(1)/τ + 0.008*noise(t)
]
g.append(0)

DDE = jitcdde(g)
DDE.add_past_points(past)
DDE.adjust_diff()

data = []
for time in np.arange(DDE.t, DDE.t+totT, 1):
    data.append(DDE.integrate(time)[0])
plt.plot(data)
plt.show()
Incidentally, I noticed that even without noise, the solution seems to be discontinuous at the point zero (y is set to be equal to zero for negative times), and I don't understand why.
As the comments unveiled, your problem eventually boiled down to this:
step_on_discontinuities assumes delays that are small with respect to the integration time and places steps at those times where the delayed components point to the integration start (0 in your case). This way initial discontinuities are handled.
However, implementing an input with a delayed dummy variable introduces a large delay into the system, totT in your case.
The respective step for step_on_discontinuities would be at totT itself, i.e., after the desired integration time.
Thus, when you reach for time in np.arange(DDE.t, DDE.t+totT, 1): in your code, DDE.t is already totT.
Therefore you have made a big step before you actually start integrating and observing, which may look like a discontinuity and lead to weird results. In particular, you do not see the effect of your input, because it has already "ended" at this point.
To avoid this, use adjust_diff or integrate_blindly instead of step_on_discontinuities.

Element-wise variance of an iterator

What's a numerically-stable way of taking the variance of an iterator elementwise? As an example, I would like to do something like
var((rand(4,2) for i in 1:10))
and get back a (4,2) matrix which is the variance in each coefficient. This throws an error using Julia's Base var. Is there a package that can handle this? Or an easy (and storage-efficient) way to do this using the Base Julia function? Or does one need to be developed on its own?
I went ahead and implemented a Welford algorithm to calculate this:
# Welford algorithm
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
function componentwise_meanvar(A; bessel=true)
    x0 = first(A)
    n = 0
    mean = zero(x0)
    M2 = zero(x0)
    delta = zero(x0)
    delta2 = zero(x0)
    for x in A
        n += 1
        delta .= x .- mean
        mean .+= delta ./ n
        delta2 .= x .- mean
        M2 .+= delta .* delta2
    end
    if n < 2
        return NaN
    else
        if bessel
            M2 .= M2 ./ (n .- 1)
        else
            M2 .= M2 ./ n
        end
        return mean, M2
    end
end
A few other algorithms are implemented in DiffEqMonteCarlo.jl as well. I'm surprised I couldn't find a library for this, but maybe will refactor this out someday.
See update below for a numerically stable version
Another method to calculate this:
srand(0) # reset random for comparing across implementations
moment2var(t) = (t[3]-t[2].^2./t[1])./(t[1]-1)
foldfunc(x,y) = (x[1]+1,x[2].+y,x[3].+y.^2)
moment2var(foldl(foldfunc,(0,zeros(1,1),zeros(1,1)),(rand(4,2) for i=1:10)))
Gives:
4×2 Array{Float64,2}:
0.0848123 0.0643537
0.0715945 0.0900416
0.111934 0.084314
0.0819135 0.0632765
Similar to:
srand(0) # reset random for comparing across implementations
# naive component-wise application of `var` function
map(var,zip((rand(4,2) for i=1:10)...))
which is the non-iterator version (or offline version in CS terminology).
This method is based on calculating the variance from the mean and the sum of squares. moment2var and foldfunc are just helper functions; without them it fits in one line.
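Concretely (adding this for reference), foldfunc accumulates t = (n, Σx, Σx²) element-wise, and moment2var then applies the usual shortcut formula var = (Σxᵢ² − (Σxᵢ)²/n) / (n − 1) component by component.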
Comments:
Speedwise, this should be pretty good as well. Perhaps using StaticArrays and initializing the foldl's v0 with the correct eltype of the iterator would save even more time.
Benchmarking gave a 5x speed advantage (and better memory usage) over componentwise_meanvar (from another answer) on a sample input.
Using moment2meanvar(t)=(t[2]./t[1],(t[3]-t[2].^2./t[1])./(t[1]-1)) gives both mean and variance like componentwise_meanvar.
As @ChrisRackauckas noted, this method suffers from numerical instability when the number of elements to sum is large.
--- UPDATE with variant of method ---
Abstracting the question a little, what we need is a way to do a foldl (and reduce, foldr) over an iterator returning matrices, element-wise and retaining shape. To do so, we can define a helper function mfold which takes a folding function and makes it fold matrices element-wise. Define it as follows:
mfold(f) = (x,y)->[f(t[1],t[2]) for t in zip(x,y)]
For this specific problem of variance, we can define the component-wise fold functions, and a final function to combine the moments into the variance (and mean if wanted). The code:
ff(x,y) = (x[1]+1,x[2]+y,x[3]+y^2) # fold and collect moments
moment2var(t) = (t[3]-t[2]^2/t[1])/(t[1]-1) # calc variance from moments
moment2meanvar(t) = (t[2]./t[1],(t[3]-t[2].^2./t[1])./(t[1]-1))
We can see moment2meanvar works on a single vector as follows:
julia> moment2meanvar(foldl(ff,(0.0,0.0,0.0),[1.0,2.0,3.0]))
(2.0, 1.0)
Now to matrix-ize it using mfold (together with dot-broadcast notation):
moment2var.(foldl(mfold(ff),fill((0,0,0),(4,2)),(rand(4,2) for i=1:10)))
@ChrisRackauckas noted this is not numerically stable, and another method (detailed in Wikipedia) is better. Using mfold it could be implemented as:
# better fold function compensating the sums for stability
ff2(x,y) = begin
    delta = y - x[2]
    mean = x[2] + delta/(x[1]+1)
    return (x[1]+1, mean, x[3] + delta*(y-mean))
end
# combine the collected information for the variance (and mean)
m2var(t) = t[3]/(t[1]-1)
m2meanvar(t) = (t[2],t[3]/(t[1]-1))
Again we have:
m2var.(foldl(mfold(ff2),fill((0,0.0,0.0),(4,2)),(rand(4,2) for i=1:10)))
Giving the same results (perhaps a little more accurately).
Or an easy (and storage-efficient) way to do this using the Base Julia function?
Out of curiosity, why is the standard solution of using var along the external dimension not good for you?
julia> var(cat(3,(rand(4,2) for i in 1:10)...),3)
4×2×1 Array{Float64,3}:
[:, :, 1] =
0.08847 0.104799
0.0946243 0.0879721
0.105404 0.0617594
0.0762611 0.091195
Obviously, I'm using cat here, which clearly is not very storage efficient, just so I can use the Base Julia function and your original generator syntax as per your question. But you could make this storage efficient as well, if you initialise your random values directly on a preallocated array of size (4,2,10), so that's not really an issue here.
Or did I misunderstand your question?
EDIT - benchmark in response to comments
function standard_var(Y, A)
    for i in 1:length(A)
        Y[:,:,i], = next(A, i);
    end
    var(Y, 3)
end

function testit()
    A = (rand(4,2) for i in 1:10000);
    Y = Array{Float64, 3}(4, 2, length(A));
    @time componentwise_meanvar(A);  # as defined in Chris's answer above
    @time standard_var(Y, A)         # standard variance + using preallocation
    @time var(cat(3, A...), 3);      # standard variance without preallocation
    return nothing
end
julia> testit()
0.004258 seconds (10.01 k allocations: 1.374 MiB)
0.006368 seconds (49.51 k allocations: 2.129 MiB)
5.954470 seconds (50.19 M allocations: 2.989 GiB, 71.32% gc time)

(Incremental)PCA's Eigenvectors are not transposed but should be?

When we posted a homework assignment about PCA we told the course participants to pick any way of calculating the eigenvectors they found. They found multiple ways: eig, eigh (our favorite was svd). In a later task we told them to use the PCAs from scikit-learn - and were surprised that the results differed a lot more than we expected.
I toyed around a bit and we posted an explanation to the participants that either solution was correct and probably just suffered from numerical instabilities in the algorithms. However, recently I picked that file up again during a discussion with a co-worker and we quickly figured out that there's an interesting subtle change to make to get all results to be almost equivalent: Transpose the eigenvectors obtained from the SVD (and thus from the PCAs).
A bit of code to show this:
def pca_eig(data):
    """Uses numpy.linalg.eig to calculate the PCA."""
    data = data.T @ data
    val, vec = np.linalg.eig(data)
    return val, vec
versus
def pca_svd(data):
    """Uses numpy.linalg.svd to calculate the PCA."""
    u, s, v = np.linalg.svd(data)
    return s ** 2, v
Does not yield the same result. Changing the return of pca_svd to s ** 2, v.T, however, works! It makes perfect sense following the definition on Wikipedia: the SVD of X follows X = UΣW^T, where
the right singular vectors W of X are equivalent to the eigenvectors of X^T X.
So to get the eigenvectors we need to transpose the output v of np.linalg.svd(...).
Unless there is something else going on? Anyway, the PCA and IncrementalPCA both show wrong results (or eig is wrong? I mean, transposing that yields the same equality), and looking at the code for PCA reveals that they are doing it as I did it initially:
U, S, V = linalg.svd(X, full_matrices=False)
# flip eigenvectors' sign to enforce deterministic output
U, V = svd_flip(U, V)
components_ = V
I created a little gist demonstrating the differences (nbviewer), the first with PCA and IncPCA as they are (also no transposition of the SVD), the second with transposed eigenvectors:
Comparison without transposition of SVD/PCAs (normalized data)
Comparison with transposition of SVD/PCAs (normalized data)
As one can clearly see, in the upper image the results are not really great, while the lower image only differs in some signs, thus mirroring the results here and there.
Is this really wrong and a bug in scikit-learn? More likely I am using the math wrong – but what is right? Can you please help me?
If you look at the documentation, it's pretty clear from the shape that the eigenvectors are in the rows, not the columns.
The point of the sklearn PCA is that you can use the transform method to do the correct transformation.
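To see the convention difference concretely, here is a small sketch of my own (not from the answer): since X = UΣW^T implies X^T X = W Σ² W^T, the rows of numpy's v (which is W^T) are the eigenvectors of X^T X, whereas np.linalg.eigh returns them as columns; up to ordering and sign the two agree.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
X -= X.mean(axis=0)                      # center, as for PCA

_, vec = np.linalg.eigh(X.T @ X)         # eigenvectors in *columns*, eigenvalues ascending
vec = vec[:, ::-1]                       # reorder to descending, to match the SVD

_, _, v = np.linalg.svd(X, full_matrices=False)   # right singular vectors in *rows* of v

# Each row of v matches the corresponding column of vec up to sign:
print(np.allclose(np.abs(v @ vec), np.eye(3), atol=1e-6))   # True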
