Is there an element-map function in PyTorch? - pytorch

I'm new to PyTorch and I come from functional programming languages (where the map function is used everywhere). The problem is that I have a tensor and I want to do some operation on each element of the tensor. The operation may vary, so I need a function like this:
map : (Numeric -> Numeric) -> Tensor -> Tensor
e.g. map(lambda x: x if x < 255 else -1, tensor) # the example is simple, but the lambda may be very complex
Is there such a function in PyTorch? How should I implement such function?

Most mathematical operations that are implemented for tensors (and similarly for ndarrays in numpy) are actually applied element-wise, so you could write, for instance:
mask = tensor < 255
result = tensor * mask + (-1) * ~mask
This is a quite general approach. For the case you have right now, where you only want to modify certain elements, you can also apply "logical indexing", which lets you overwrite the current tensor in place:
tensor[tensor >= 255] = -1
So in Python there actually is a map() function, but usually there are better ways to do it (better in Python; in other languages - like Haskell - map/fmap is obviously preferred in most contexts).
So the key take-away here is that the preferred method is to take advantage of vectorization. This also makes the code more efficient, as those tensor operations are implemented in a low-level language, while map() is essentially a Python for-loop that is a lot slower.
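If you really do want a literal per-element map, two options worth knowing (not part of the original answer, just standard PyTorch API): torch.where expresses the conditional lambda in vectorized form, and Tensor.apply_ applies an arbitrary Python callable element-wise (CPU tensors only, and roughly as slow as a Python loop):
import torch
tensor = torch.tensor([100., 200., 300.])
# vectorized equivalent of: map(lambda x: x if x < 255 else -1, tensor)
result = torch.where(tensor < 255, tensor, torch.tensor(-1.))
print(result)  # tensor([100., 200.,  -1.])
# literal element-wise map; CPU only, runs a Python loop under the hood
mapped = tensor.clone().apply_(lambda x: x if x < 255 else -1)
print(mapped)  # tensor([100., 200.,  -1.])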

Related

Is vmap efficient as compared to batched ops?

I am playing around with some Jax and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch" it. Is this the correct way? Other tools I have worked with in the past (pytorch, tf) typically have a "batch" dimension that is kind of implicit. I kind of assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
is vmap the correct/expected way to batch-train models in JAX (at least most of the time)?
is it not the case that the per-operation batching would be somehow faster and handled by some cuda function (in the case of using cuda) someplace more naturally? Does JAX realize that, say, it is not vmapping over my model parameter dimensions and use the correct batched matmuls and other ops? Or is it that the ops don't actually work like this, and vmapping (naively batching over the entire sequence of calculations) is actually what's happening even in something like pytorch?
This is theoretical question. My code currently works, but I am just curious as to the "why" of my approach.
vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available, JAX just doesn't expose it until it's needed.
If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to re-define your model so that it accepts batches. This takes some thought; in particular, we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace the dot with the appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
But it's not great to have to re-write the model every time we want to batch an operation... this is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!) and it produces the same result given the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
You might ask now which one of these approaches is more efficient: it turns out that under the hood, both the manual and automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
c:f32[2,5] = xla_call[
call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
f:f32[2,5] = dot_general[
dimension_numbers=(((1,), (1,)), ((), ()))
precision=None
preferred_element_type=None
] d e
in (f,) }
name=_einsum
] a b
g:f32[2,5] = add c 1.0
h:f32[2,5] = tanh g
i:f32[5] = reduce_sum[axes=(0,)] h
in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
c:f32[2,5] = dot_general[
dimension_numbers=(((1,), (1,)), ((), ()))
precision=None
preferred_element_type=None
] a b
d:f32[2,5] = add c 1.0
e:f32[2,5] = tanh d
f:f32[5] = reduce_sum[axes=(0,)] e
in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations, but you'll see that the actual sequence of operations is identical between the two approaches: it's dot_general, then add, then tanh, then reduce_sum.
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
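If you want to convince yourself of the "identical performance" claim, a rough timing sketch (not from the original answer; it reuses model, model_batched and v_batched from the snippets above) is to jit-compile both versions and time them after a warm-up call:
import timeit
batched = jax.jit(model_batched)
vmapped = jax.jit(jax.vmap(model))
# warm-up calls trigger compilation so it is excluded from the timings
batched(v_batched).block_until_ready()
vmapped(v_batched).block_until_ready()
print(timeit.timeit(lambda: batched(v_batched).block_until_ready(), number=1000))
print(timeit.timeit(lambda: vmapped(v_batched).block_until_ready(), number=1000))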

How to improve this toy Jax optimizer code with while loops and saved history?

I'm writing a custom optimizer I want JIT-able with Jax which features 1) breaking on maximum steps reached 2) breaking on a tolerance reached, and 3) saving the history of the steps taken. I'm relatively new to some of this stuff in Jax, but reading the docs I have this solution:
import jax, jax.numpy as jnp

@jax.jit
def optimizer(x, tol=1, max_steps=5):
    def cond(arg):
        step, x, history = arg
        return (step < max_steps) & (x > tol)
    def body(arg):
        step, x, history = arg
        x = x / 2                          # simulate taking an optimizer step
        history = history.at[step].set(x)  # simulate saving the current step
        return (step + 1, x, history)
    return jax.lax.while_loop(
        cond,
        body,
        (0, x, jnp.full(max_steps, jnp.nan))
    )

optimizer(10.)  # works
My question is whether this can be improved in some way? In particular, is there a way to avoid pre-allocating the history? This isn't ideal, since the real thing is a lot more complicated than a single array, and there's obviously the potential for wasted memory if the tolerance is reached well before the maximum number of steps.
is there a way to avoid pre-allocating the history?
No, not as far as I understand JAX.
In JAX, a value's 'type' includes its shape; that is, the input and output of the body function MUST have the same shapes. If you tried to grow the history dynamically, say with jnp.vstack((history, x)), the carried value's shape would change between iterations and JAX would reject it.
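To see that constraint concretely, here is a minimal sketch (not from the original answer): growing the history inside the loop body changes the shape of the carried value, so jax.lax.while_loop refuses it because the body's output type no longer matches its input type:
import jax, jax.numpy as jnp
def cond(arg):
    step, x, history = arg
    return (step < 5) & (x > 1.0)
def body(arg):
    step, x, history = arg
    x = x / 2
    history = jnp.concatenate([history, x[None]])  # history grows each iteration
    return (step + 1, x, history)
# This fails at trace time: the carried history has shape (0,) on input
# but shape (1,) on output, so the loop carry is not a fixed type.
jax.lax.while_loop(cond, body, (0, jnp.array(10.0), jnp.zeros(0)))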
There is a way, if you think the tolerance will often be reached before the maximum number of steps.
JAX implements sparse matrices (and pytrees of them), in jax.experimental.sparse. They will have the same shape as the maximum history size, and therefore satisfy the "fixed size" requirement for XLA, but of course, will only store nonzero elements in memory.
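As a small illustrative sketch of that idea (not from the original answer): a sparse array keeps the full logical shape required by XLA while only storing the entries that have actually been written:
from jax.experimental import sparse
import jax.numpy as jnp
dense = jnp.zeros(1000).at[3].set(7.0)   # preallocated history with one real entry
sp = sparse.BCOO.fromdense(dense)
print(sp.shape)      # (1000,) -- the logical shape is the full preallocated size
print(sp.data.size)  # 1 -- only the nonzero entry is stored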

How to get a 2D output from linear layer in pytorch?

I would like to project a tensor into a space with an additional dimension.
I tried
torch.nn.Linear(
    in_features=num_inputs,
    out_features=(num_inputs, num_additional),
)
But this results in an error
A workaround would be to
torch.nn.Linear(
    in_features=num_inputs,
    out_features=num_inputs*num_additional,
)
and then change the view of the output:
output.view(batch_size, num_inputs, num_additional)
But I imagine this workaround will get tricky to read, especially when a projection into more than one additional dimension is desired.
Is there a more direct way to code this operation?
Perhaps the source code for Linear (https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear) could be changed to accept more dimensions for the weight and bias initialization, and F.linear seems like it would need to be replaced with a different function.
IMO the workaround you provided is already clear enough. However, if you want to express this as a single operation, you can always write your own module by subclassing torch.nn.Linear:
import numpy as np
import torch
class MultiDimLinear(torch.nn.Linear):
    def __init__(self, in_features, out_shape, **kwargs):
        self.out_shape = out_shape
        out_features = np.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        return out.reshape((len(x), *self.out_shape))

if __name__ == '__main__':
    tmp = torch.empty((32, 10))
    linear = MultiDimLinear(in_features=10, out_shape=(10, 10))
    out = linear(tmp)
    print(out.shape)  # (32, 10, 10)
Another way would be to use torch.einsum
https://pytorch.org/docs/stable/generated/torch.einsum.html
torch.einsum lets you prevent summation across chosen dimensions in tensor-to-tensor multiplication operations, which allows separate multiplication operations to happen in parallel. [I do not know whether this necessarily results in GPU efficiency if the operations still occur in the same kernel; in fact, it may be slower: https://github.com/pytorch/pytorch/issues/32591]
How this would work is to directly initialize the weight and bias tensors (look at source code for the torch linear layer for that code)
Say that the input (X) has dimensions (a, b), where a is the batch size.
Say that you want to pass this input through a series of classifiers, represented by a single weight tensor (W) with dimensions (c, d, e), where c is the number of classifiers, d is the input feature dimension, and e is the number of classes for each classifier.
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*6).view(5, 4, 6)
torch.einsum('ab, cbe -> ace', x, w)
In the last line, a and b are the dimensions of the input, as mentioned above. The slightly tricky part is that c, b, and e are the dimensions of the classifier weight tensor; I didn't use d, I used b instead, because the vector multiplication happens along that shared dimension of the input tensor and the weight tensor. That's why the left side of the einsum equation is 'ab, cbe'. The right side of the einsum equation simply lists the dimensions to keep; anything left out is summed over.
The final dimensions we want are (a, c, e): a is the batch size, c is the number of classifiers, and e is the number of classes for each classifier. We do not want to sum across those dimensions, so to preserve their separation, the output side of the equation is ace.
For those unfamiliar with einsum, this will be harder to read than the workaround I created (though I highly recommend learning it, because it gets very easy and intuitive very fast, even though it's a bit tricky at first: https://www.youtube.com/watch?v=pkVwUVEHmfI).
However, for parallelizing certain operations (especially on GPU), it seems that einsum is the only way to do it. For example, suppose that in my previous example I didn't want to use a classification head yet; I just wanted to project to multiple dimensions.
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*4).view(5, 4, 4)
y = torch.einsum('ab, cbe -> ace', x, w)
And say I do a few other operations to y, perhaps some non-linear operations, activations, etc.
z = f(y)
z will still have the dimensions (2, 5, 4): batch size 2, 5 hidden states per batch, and the dimension of those hidden states is 4.
And then I want to apply a classifier to each separate tensor.
w2 = torch.arange(4*2).view(4, 2)
final = torch.einsum('fgh, hj -> fgj', z, w2)
Quick refresh: 2 is the batch size, 5 is the number of classifiers, and 2 is the number of outputs for each classifier.
The output dimensions, f, g, j (2, 5, 2) will not be summed across, and thus will be preserved in the output.
As cited in the github link, this may be slower than just using regular linear layers. There may be efficiencies in a very large number of parallel operations.
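For completeness, a small sanity check (not from the original answer) that the einsum really does the same thing as an explicit per-classifier loop of matrix multiplications:
import torch
x = torch.randn(2, 4)     # (batch, features)
w = torch.randn(5, 4, 6)  # (classifiers, features, classes)
out_einsum = torch.einsum('ab,cbe->ace', x, w)               # shape (2, 5, 6)
out_loop = torch.stack([x @ w[c] for c in range(5)], dim=1)  # one matmul per classifier
print(torch.allclose(out_einsum, out_loop))  # True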

matrix multiplication for complex numbers in PyTorch

I am trying to multiply two complex matrices in PyTorch, and it seems that the torch.matmul function has not yet been added to the PyTorch library for complex numbers.
Do you have any recommendation or is there another method to multiply complex matrices in PyTorch?
Currently torch.matmul is not supported for complex tensors such as ComplexFloatTensor, but you could do something as compact as the following code:
def matmul_complex(t1, t2):
    return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag,
                                              t1.real @ t2.imag + t1.imag @ t2.real), dim=2))
When possible avoid using for loops as these will result in much slower implementations.
Vectorization is achieved by using built-in methods as demonstrated in the code I have attached.
For example, your code takes roughly 6.1s on CPU while the vectorized version takes only 101ms (~60 times faster) for 2 random complex matrices with dimensions 1000 X 1000.
Update:
Since PyTorch 1.7.0 (as @EduardoReis mentioned) you can do matrix multiplication between complex matrices just as you would between real-valued matrices:
t1 @ t2
(for t1, t2 complex matrices).
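A quick check of the native behavior (not in the original answer; assumes PyTorch >= 1.7.0):
import torch
a = torch.randn(3, 4, dtype=torch.cfloat)
b = torch.randn(4, 5, dtype=torch.cfloat)
c = a @ b                # complex matmul works natively
print(c.shape, c.dtype)  # torch.Size([3, 5]) torch.complex64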
I implemented this function for torch.matmul for complex numbers using torch.mv, and it's working fine for the time being:
def matmul_complex(t1, t2):
    m = t1.size(0)
    n = t2.size(1)
    t_total = torch.empty(0, dtype=torch.cfloat)
    for i in range(n):
        # append the i-th column of the product as a matrix-vector product
        t_total = torch.cat((t_total, torch.mv(t1, t2[:, i])), 0)
    # the n columns of length m were concatenated in order,
    # so reshape to (n, m) and transpose to get the (m, n) result
    t_final = torch.reshape(t_total, (n, m)).t()
    return t_final
I am new to PyTorch, so please correct me if I am wrong.

Pytorch parameter matrix from loss function of transformation

I have a k x (n+k-1) pytorch tensor w with requires_grad=True. I want to transform it into a k x n tensor p such that p[i] = w[i][i:i+n]. How do I do this, such that by calling backward() on a loss function of p at the end, I will learn w?
Any sort of indexing operation will do, with the backward function being <CopySlices>. A naive way of doing this would be to use simple Python indexing:
w_unrolled = torch.zeros(p.size())
for i in range(w.shape[0]):
    w_unrolled[i] = w[i][i:i+n]
loss = criterion(w_unrolled, p)
You can then reduce your loss via mean/sum on whichever axis. Note that while this will work, it is inefficient; the optimal way would be to use a native indexing function to speed things up.
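As a sketch of what such a native indexing approach could look like (not from the original answer; k and n are the dimensions from the question), torch.gather builds p in a single differentiable operation:
import torch
k, n = 3, 4
w = torch.randn(k, n + k - 1, requires_grad=True)
# idx[i, j] = i + j, so gathering along dim 1 picks out w[i, i:i+n] for each row i
idx = torch.arange(k).unsqueeze(1) + torch.arange(n).unsqueeze(0)  # shape (k, n)
p = w.gather(1, idx)
loss = p.sum()
loss.backward()       # gradients flow back into w through the gather
print(w.grad.shape)   # torch.Size([3, 6])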
