I am playing around with some JAX and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch." Is this the correct way? Other tools I have worked with in the past (PyTorch, TF) typically have a "batch" dimension that is kind of implicit. I assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
Is vmap the correct/expected way to batch when training models in JAX (at least most of the time)?
Isn't per-operation batching somehow faster, and handled more naturally by some CUDA function somewhere (in the case of using CUDA)? Does JAX realize that, say, it's not vmapping over my model parameter dimensions, and use the correct batched matmuls and other ops? Or do the ops not actually work like this, and is vmapping (naively batching over the entire sequence of calculations) actually what's happening even in something like PyTorch?
This is a theoretical question. My code currently works, but I am curious about the "why" of my approach.
vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available, JAX just doesn't expose it until it's needed.
If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to re-define your model so that it accepts batches. This takes some thought; in particular, we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace dot with appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
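The same manual vectorization can also be written without einsum, by multiplying the batch by the transpose of M. Here is a minimal, self-contained sketch with its own random data; the name `model_matmul` is hypothetical, and the tolerance in the comparison is an assumption chosen to allow for reduced-precision matmuls on some GPUs:

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
M = jnp.array(rng.normal(size=(2, 3)))          # parameters, shape (k, m)
b = 1.0
v_batched = jnp.array(rng.normal(size=(5, 3)))  # inputs, shape (n, m)

def model(v, M=M, b=b):
    # Single-example model: v has shape (m,)
    return jnp.tanh(M @ v + b).sum()

def model_matmul(v_batched, M=M, b=b):
    # (n, m) @ (m, k) -> (n, k); summing over k gives one value per example
    return jnp.tanh(v_batched @ M.T + b).sum(axis=1)

loop_result = jnp.array([model(v) for v in v_batched])  # slow reference
print(jnp.allclose(model_matmul(v_batched), loop_result, rtol=1e-3, atol=1e-3))
```

Whether you write einsum or matmul is a matter of taste; both express the same contraction over m.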
But it's not great to have to rewrite the model every time we want to batch an operation. This is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!), and it produces the same result from the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
You might ask now which one of these approaches is more efficient: it turns out that under the hood, both the manual and automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations, but you'll see that the actual sequence of operations is identical between the two approaches: it's dot_general, then add, then tanh, then reduce_sum.
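On the question of whether vmap knows not to batch over the model parameters: it only adds a batch dimension to the arguments you map over. A minimal sketch, assuming the parameters are passed explicitly rather than bound as defaults, where `in_axes=(0, None, None)` maps over the first axis of `v` and passes `M` and `b` through unbatched to every example (the printed jaxpr should then show a single dot_general against the unbatched M, as in the output above):

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(1)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
v_batched = jnp.array(rng.normal(size=(5, 3)))

def model(v, M, b):
    # Written for a single example v of shape (3,)
    return jnp.tanh(M @ v + b).sum()

# Map over axis 0 of v only; M and b are shared by every example.
batched_model = jax.vmap(model, in_axes=(0, None, None))

out = batched_model(v_batched, M, b)
print(out.shape)  # (5,)
print(jax.make_jaxpr(batched_model)(v_batched, M, b))
```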
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
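For training (the first question), one common pattern is to write the loss for a single example, vmap it over the batch, average, and differentiate the average with jax.grad. The sketch below assumes a squared-error loss, zero targets, and plain gradient descent with a fixed step size; all of these are hypothetical choices for illustration, not something from the answer above:

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.01  # hypothetical step size

def model(params, v):
    # Forward pass for ONE example v of shape (3,)
    return jnp.tanh(params["M"] @ v + params["b"]).sum()

def example_loss(params, v, y):
    # Hypothetical squared-error loss for a single (input, target) pair
    return (model(params, v) - y) ** 2

def batch_loss(params, v_batched, y_batched):
    # Map over examples (axis 0 of v and y); params are shared (None)
    losses = jax.vmap(example_loss, in_axes=(None, 0, 0))(params, v_batched, y_batched)
    return losses.mean()

@jax.jit
def sgd_step(params, v_batched, y_batched):
    # Gradient of the batch-mean loss with respect to every leaf of params
    grads = jax.grad(batch_loss)(params, v_batched, y_batched)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

rng = np.random.default_rng(2)
params = {"M": jnp.array(rng.normal(size=(2, 3))), "b": jnp.array(1.0)}
v_batched = jnp.array(rng.normal(size=(8, 3)))
y_batched = jnp.zeros(8)  # hypothetical targets

loss_before = batch_loss(params, v_batched, y_batched)
for _ in range(100):
    params = sgd_step(params, v_batched, y_batched)
loss_after = batch_loss(params, v_batched, y_batched)
print(loss_before, loss_after)
```

The model and loss are still written for one example; batching is added once, at the loss level, and jit compiles the vmapped computation into the same kind of batched operations shown in the jaxprs above.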
Consider the case where I have a set of nodes, and I want to declare some ordering over them. The easiest way to do this is to declare both the set of nodes and their ranking as constants:
CONSTANTS Node, NodeRank
ASSUME NodeRank \in [Node -> Nat]
ASSUME \A n, m \in Node : NodeRank[n] = NodeRank[m] <=> n = m
Now it comes time to assign model values to these constants. Node is easy enough, I just define it as a set of model values in the toolbox:
Node <- [ model value ] {n1, n2}
I try to do something similar with NodeRank with ordinary assignment:
NodeRank <- [n1 |-> 1, n2 |-> 2]
However, when I run TLC the ASSUME statements are violated. Further examination reveals this is because in the ordinary assignment of NodeRank, n1 and n2 are treated as strings instead of model values. This makes sense, because that is the usual method of defining records (which use strings as their domain). How do I define NodeRank such that it uses the n1 and n2 model values as its domain?
If your module extends TLC, you can write this as n1 :> 1 @@ n2 :> 2.
I have a use case where I want to count types of elements in an RDD matching some filter.
e.g. RDD.filter(F1) and RDD.filter(!F1)
I have 2 options
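Use accumulators, e.g.:
LongAccumulator l1 = sparkContext.longAccumulator("Count1");
LongAccumulator l2 = sparkContext.longAccumulator("Count2");
RDD.foreachPartition(it -> {
    // Iterate over every element in the partition, not once per partition
    while (it.hasNext()) {
        if (F1(it.next())) l1.add(1);
        else l2.add(1);
    }
});
Use count():
RDD.filter(x -> F1(x)).count(); RDD.filter(x -> !F1(x)).count();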
One benefit of the first approach is that we only need to iterate over the data once (useful since my data set is tens of TB).
What is the use of count() if the same effect can be achieved with accumulators?
The major difference is that if your code fails during the transformation, the accumulators will already have been updated, whereas the count() result will not.
Another option is to use plain map-reduce:
val counts = rdd.map(x => (F1(x), 1)).reduceByKey(_ + _).collectAsMap()
The network cost should also be low, as only a few numbers are sent. It creates pairs of (whether F1(x) is true or false, 1) and then sums the ones, which gives you the number of items for both F1(x) and !F1(x) in the counts map.
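The single-pass counting pattern can be sketched in plain Python, with no Spark involved: the list stands in for the RDD and `is_f1` is a hypothetical stand-in for the F1 predicate. Each element is mapped to a `(F1(x), 1)` pair and the ones are summed per key, which is the work that `reduceByKey(_ + _)` followed by `collectAsMap()` distributes across partitions:

```python
from collections import defaultdict

def is_f1(x):
    # Hypothetical predicate standing in for F1
    return x % 3 == 0

def count_by_predicate(items, pred):
    # "map" step: emit one (key, 1) pair per element
    pairs = ((pred(x), 1) for x in items)
    # "reduceByKey(_ + _)" step: sum the ones for each key
    counts = defaultdict(int)
    for key, one in pairs:
        counts[key] += one
    return dict(counts)

counts = count_by_predicate(range(10), is_f1)
print(counts)  # {True: 4, False: 6}
```

If no element satisfies (or fails) the predicate, that key is simply absent, as it would be from `collectAsMap()`, so read the result with `counts.get(False, 0)`. In PySpark the same pipeline would be `rdd.map(lambda x: (is_f1(x), 1)).reduceByKey(lambda a, b: a + b).collectAsMap()`.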
I'm curious about when evaluation happens; apparently certain operators are transformed into clauses rather than evaluated:
abstract sig Element {}
one sig A, B, C extends Element {}
one sig Test {
test: set Element
}
pred Test1 { Test.test = A+B }
pred Test2 { Test.test = Element-C }
Running it for Test1 and Test2 respectively gives a different number of vars/clauses, specifically:
Test1: 0 vars, 0 primary vars, 0 clauses
Test2: 5 vars, 3 primary vars, 4 clauses
So although Element is abstract and all its members and their cardinalities are known, the difference seems not to be computed in advance, while the sum is. I don't want to make any assumptions, so I'm interested in why that is. Is the + operator special?
To give some context, I tried to limit the domain of a relation and found that using only + seems to be more efficient, even when the sets are completely known in advance.
That is pretty much the right conclusion. The reason is that the Alloy Analyzer tries to infer relation bounds from certain Alloy idioms. It uses a conservative approximation that is always sound for set union and product, but not for set difference. That's why, for Test1 in the example above, the Alloy Analyzer infers a fixed bound for the test relation (this/Test.test: [[[A$0], [B$0]]]), so no solver needs to be invoked; for Test2, the bound for the test relation cannot be shrunk, so it is set to the most permissive one (this/Test.test: [[], [[A$0], [B$0], [C$0]]]), and a solver needs to be invoked to find a solution satisfying the constraints given the bounds.
I'm working with noisy data in IDL, so I've been using STDDEV and robust_sigma. There are papers on robust skewness and kurtosis, for instance [1] and [2], but are there implementations of them, as there are for the standard deviation (in IDL or maybe C)?
The documentation of http://idlastro.gsfc.nasa.gov/ftp/pro/robust/robust_sigma.pro states:
; OPTIONAL OUPTUT KEYWORD:
; GOODVEC = Vector of non-trimmed indices of the input vector
So one calls robust_sigma with an extra parameter that keeps track of the "good indices" in the data, those used to compute the robust_sigma, as opposed to those ignored in its computation.
good_indices = lonarr(width)
robo_2 = robust_sigma(data[*], GOODVEC=good_indices)
Then use (only) those good indices to compute the other moments.
robo_3 = skewness(data[good_indices])
robo_4 = kurtosis(data[good_indices])
No need for a special implementation.
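The same "trim, then compute the ordinary moments on the surviving points" idea can be sketched in Python with NumPy. This is not robust_sigma: as a swapped-in, hypothetical stand-in for its GOODVEC output, it keeps points within a cutoff of scaled median absolute deviations (MAD) from the median, and it uses population-moment formulas that may differ in normalization from IDL's SKEWNESS and KURTOSIS:

```python
import numpy as np

def good_indices_mad(data, cutoff=3.5):
    # Hypothetical stand-in for GOODVEC: indices of points within
    # `cutoff` scaled MADs of the median (not robust_sigma's biweight)
    data = np.asarray(data, dtype=float)
    med = np.median(data)
    mad = np.median(np.abs(data - med))
    if mad == 0:
        return np.arange(data.size)
    sigma = 1.4826 * mad  # MAD rescaled to a Gaussian-equivalent sigma
    return np.flatnonzero(np.abs(data - med) <= cutoff * sigma)

def skewness(x):
    # Population skewness: third central moment / sigma**3
    x = np.asarray(x, dtype=float)
    d = x - x.mean()
    return np.mean(d**3) / np.mean(d**2) ** 1.5

def excess_kurtosis(x):
    # Population kurtosis minus 3 (zero for a Gaussian)
    x = np.asarray(x, dtype=float)
    d = x - x.mean()
    return np.mean(d**4) / np.mean(d**2) ** 2 - 3.0

rng = np.random.default_rng(3)
data = np.concatenate([rng.normal(size=1000), [50.0, 60.0, 70.0]])  # 3 outliers
good = good_indices_mad(data)
print(len(good), skewness(data[good]), excess_kurtosis(data[good]))
```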