Only backpropagate up to a given variable - pytorch

I’m working on a GAN where the discriminator operates on the latent space vector produced by the encoder. The details aren’t important, but the paper describing the model is https://arxiv.org/abs/1706.00409 if you want to take a look.
Essentially, my problem is that my training code requires an unnecessary backward pass through the encoder and I’m not sure how to get around this such that it becomes optimal. Here is the relevant code, where E is the encoder and D is the discriminator.
latent_vec = E(input) #latent_vec will be a Variable with requires_grad=True
predictions = D(latent_vec)
e_loss = encoder_loss(predictions, ground_truth)
e_optimizer.zero_grad()
e_loss.backward() #backpropagates through both D and E, which is necessary
e_optimizer.step()
d_loss = discriminator_loss(predictions, ground_truth)
d_optimizer.zero_grad()
d_loss.backward() #backpropagates through D, but also through E unnecessarily
d_optimizer.step()
This still works because the optimizers are only modifying the parameters of their respective models, but it’s inefficient because d_loss.backward() unnecessarily backpropagates through E. I’m aware that I can recreate a version of latent_vec that will prevent backpropagation through E using
latent_vec_no_grad = Variable(latent_vec.data), but then I would be stuck with 2 forward passes through D (one with latent_vec so that e_loss backpropagates through E, and one with latent_vec_no_grad so that d_loss only backpropagates through D).
Ideally, a flag such as latent_vec.block_backprop = True could be set after optimizing E, but no such flag exists. Is there an elegant solution that would make training optimal?
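For reference, the two-forward-pass workaround I described would look roughly like this (just a sketch; latent_vec.detach() is the modern equivalent of Variable(latent_vec.data)):
latent_vec = E(input)
predictions = D(latent_vec)                      # graph reaches back through D and E
e_loss = encoder_loss(predictions, ground_truth)
e_optimizer.zero_grad()
e_loss.backward()                                # backpropagates through D and E (needed)
e_optimizer.step()
latent_vec_no_grad = latent_vec.detach()         # same values, but cut off from E's graph
predictions_no_grad = D(latent_vec_no_grad)      # second forward pass through D
d_loss = discriminator_loss(predictions_no_grad, ground_truth)
d_optimizer.zero_grad()
d_loss.backward()                                # backpropagates through D only
d_optimizer.step()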

Related

Is vmap efficient as compared to batched ops?

I am playing around with some Jax and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch". Is this the correct way? Other tools I have worked with in the past (PyTorch, TF) typically have a "batch" dimension that is kind of implicit. I kind of assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
1. Is vmap the correct/expected way to batch-train models in JAX (at least most of the time)?
2. Is it not the case that per-operation batching would somehow be faster and handled by some CUDA function (in the case of using CUDA) someplace more naturally? Does JAX realize that, say, it is not vmapping over my model parameter dimensions and use the correct batched matmuls and other ops? Or is it that the ops don't actually work like this, and vmapping (naively batching over the entire sequence of calculations) is actually what's happening even in something like PyTorch?
This is a theoretical question. My code currently works, but I am just curious as to the "why" of my approach.
vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available, JAX just doesn't expose it until it's needed.
If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to re-define your model so that it accepts batches. This takes some thought, in particular we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace the simple dot with the appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
But it's not great to have to re-write the model every time we want to batch an operation... this is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!) and it produces the same result given the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
You might ask now which one of these approaches is more efficient: it turns out that under the hood, both the manual and automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations, but you'll see that the actual sequence of operations is identical between the two approaches: it's dot_general, then add, then tanh, then reduce_sum.
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
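If you want to verify the performance claim on your own hardware, a rough timing sketch (reusing model, model_batched and v_batched from above, and jitting both functions so that compilation happens once before timing) could look like this:
import timeit
batched_fn = jax.jit(model_batched)
vmapped_fn = jax.jit(jax.vmap(model))
# trigger compilation once before timing
batched_fn(v_batched).block_until_ready()
vmapped_fn(v_batched).block_until_ready()
print(timeit.timeit(lambda: batched_fn(v_batched).block_until_ready(), number=1000))
print(timeit.timeit(lambda: vmapped_fn(v_batched).block_until_ready(), number=1000))
The two timings should come out essentially the same, since both functions lower to the same sequence of operations.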

How to get a 2D output from linear layer in pytorch?

I would like to project a tensor into a space with an additional dimension.
I tried
torch.nn.Linear(
in_features=num_inputs,
out_features=(num_inputs, num_additional),
)
But this results in an error
A workaround would be to
torch.nn.Linear(
in_features=num_inputs,
out_features=num_inputs*num_additional,
)
and then change the view of the output
output.view(batch_size, num_inputs, num_additional)
But I imagine this workaround will get tricky to read, especially when a projection into more than one additional dimension is desired.
Is there a more direct way to code this operation?
Perhaps the source code for Linear could be changed
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
to accept more dimensions for the weight and bias initialization; F.linear seems like it would need to be replaced with a different function.
IMO the workaround you provided is already clear enough. However, if you want to express this as a single operation, you can always write your own module by subclassing torch.nn.Linear:
import numpy as np
import torch
class MultiDimLinear(torch.nn.Linear):
    def __init__(self, in_features, out_shape, **kwargs):
        self.out_shape = out_shape
        out_features = np.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        return out.reshape((len(x), *self.out_shape))

if __name__ == '__main__':
    tmp = torch.empty((32, 10))
    linear = MultiDimLinear(in_features=10, out_shape=(10, 10))
    out = linear(tmp)
    print(out.shape)  # (32, 10, 10)
Another way would be to use torch.einsum
https://pytorch.org/docs/stable/generated/torch.einsum.html
torch.einsum lets you control which dimensions get summed over in tensor-to-tensor multiplication operations, so separate multiplications can happen in parallel. [I do not know whether this necessarily results in GPU efficiency, i.e. whether the operations still occur in the same kernel; in fact, it may be slower: https://github.com/pytorch/pytorch/issues/32591 ]
How this would work is to directly initialize the weight and bias tensors (look at source code for the torch linear layer for that code)
Say that the input (X) has dimensions (a, b), where a is the batch size.
Say that you want to pass this input through a series of classifiers, represented as a single weight tensor (W) with dimensions (c, d, e), where c is the number of classifiers and e is the number of classes for each classifier.
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*6).view(5, 4, 6)
torch.einsum('ab, cbe -> ace', x, w)
In the last line, a and b are the dimensions of the input as mentioned above. What might be the tricky part is that c, b, and e are the dimensions of the classifiers' weight tensor; I didn't use d, I used b instead, because the vector multiplication happens along that dimension for both the input tensor and the weight tensor. So that's why the left side of the einsum equation is ab, cbe. The right side of the einsum equation is simply which dimensions to keep, i.e. to exclude from summation.
The final dimensions we want are (a, c, e): a is the batch size, c is the number of classifiers, and e is the number of classes for each classifier. We do not want to sum across those dimensions, so to preserve their separation, the right side of the equation is ace.
For those unfamiliar with einsum, this will be harder to read than the workaround I created (though I highly recommend learning it, because it gets very easy and intuitive very fast, even though it's a bit tricky at first: https://www.youtube.com/watch?v=pkVwUVEHmfI ).
However, for parallelizing certain operations (especially on GPU), it seems that einsum is the only way to do it. For example, suppose that in my previous example I didn't want to use a classification head yet; I just wanted to project to multiple dimensions.
import torch
x = torch.arange(2*4).view(2, 4)
w = torch.arange(5*4*4).view(5, 4, 4)
y = torch.einsum('ab, cbe -> ace', x, w)
And say I do a few other operations to y, perhaps some non linear operations, activations, etc.
z = f(y)
z will still have the dimensions (2, 5, 4): batch size 2, 5 hidden states per batch element, and the dimension of those hidden states is 4.
And then I want to apply a classifier to each separate tensor.
w2 = torch.arange(4*2).view(4, 2)
final = torch.einsum('fgh, hj -> fgj', z, w2)
Quick refresher: 2 is the batch size, 5 is the number of classifiers, and 2 is the number of outputs for each classifier.
The output dimensions, f, g, j (2, 5, 2) will not be summed across, and thus will be preserved in the output.
As cited in the GitHub link, this may be slower than just using regular linear layers; the gains, if any, would come from running a very large number of operations in parallel.
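To make the parallel-classifier idea concrete, here is a rough sketch of how the directly-initialized weight and bias tensors and the einsum could be wrapped into a module (the class name, sizes and initialization scheme are my own choices, loosely mirroring what nn.Linear does, not code taken from PyTorch):
import math
import torch

class ParallelLinear(torch.nn.Module):
    """Applies num_heads independent (in_features -> out_features) linear maps in parallel."""
    def __init__(self, num_heads, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(num_heads, in_features, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(num_heads, out_features))
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):  # x: (batch, in_features)
        # 'ab,cbe->ace' keeps batch (a) and head (c) separate and sums over b
        return torch.einsum('ab,cbe->ace', x, self.weight) + self.bias

x = torch.randn(2, 4)
layer = ParallelLinear(num_heads=5, in_features=4, out_features=6)
print(layer(x).shape)  # torch.Size([2, 5, 6])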

Multiclass semantic segmentation model evaluation

I am doing a project on multiclass semantic segmentation. I have formulated a model that outputs pretty decent segmented images by decreasing the loss value. However, I cannot evaluate the model performance in metrics such as mean IoU or the Dice coefficient.
In case of binary semantic segmentation it was easy just to set the threshold of 0.5, to classify the outputs as an object or background, but it does not work in the case of multiclass semantic segmentation. Could you please tell me how to obtain model performance on the aforementioned metrics? Any help will be highly appreciated!
By the way, I am using PyTorch framework and CamVid dataset.
If anyone is interested in this answer, please also look at this issue. The author of the issue points out that mIoU can be computed in a different way (and that method is more accepted in literature). So, consider that before using the implementation for any formal publication.
Basically, the other method suggested by the issue-poster is to separately accumulate the intersections and unions over the entire dataset and divide them at the final step. The method in the below original answer computes intersection and union for a batch of images, then divides them to get IoU for the current batch, and then takes a mean of the IoUs over the entire dataset.
However, the original method given below is problematic because the final mean IoU varies with the batch size. On the other hand, the mIoU does not vary with the batch size for the method mentioned in the issue, as the separate accumulation ensures that the batch size is irrelevant (though a higher batch size can definitely help speed up the evaluation).
Original answer:
Given below is an implementation of mean IoU (Intersection over Union) in PyTorch.
import numpy as np
import torch
import torch.nn.functional as F

def mIOU(label, pred, num_classes=19):
    pred = F.softmax(pred, dim=1)
    pred = torch.argmax(pred, dim=1).squeeze(1)
    iou_list = list()
    present_iou_list = list()
    pred = pred.view(-1)
    label = label.view(-1)
    # Note: Following for loop goes from 0 to (num_classes-1)
    # and ignore_index is num_classes, thus ignore_index is
    # not considered in computation of IoU.
    for sem_class in range(num_classes):
        pred_inds = (pred == sem_class)
        target_inds = (label == sem_class)
        if target_inds.long().sum().item() == 0:
            iou_now = float('nan')
        else:
            intersection_now = (pred_inds[target_inds]).long().sum().item()
            union_now = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection_now
            iou_now = float(intersection_now) / float(union_now)
            present_iou_list.append(iou_now)
        iou_list.append(iou_now)
    return np.mean(present_iou_list)
The prediction of your model will have one channel per class (class scores), so first take the softmax (if your model doesn't already) followed by argmax to get the index with the highest probability at each pixel. Then, we calculate the IoU for each class (and take the mean over them at the end).
We can reshape both the prediction and the label as 1-D vectors (I read that it makes the computation faster). For each class, we first identify the indices of that class using pred_inds = (pred == sem_class) and target_inds = (label == sem_class). The resulting pred_inds and target_inds will have 1 at pixels labelled as that particular class while 0 for any other class.
Then, there is a possibility that the target does not contain that particular class at all. This will make that class's IoU calculation invalid as it is not present in the target. So, you assign such classes a NaN IoU (so you can identify them later) and not involve them in the calculation of the mean.
If the particular class is present in the target, then pred_inds[target_inds] will give a vector of 1s and 0s where indices with 1 are those where prediction and target are equal and zero otherwise. Taking the sum of all elements of this will give us the intersection.
If we add all the elements of pred_inds and target_inds, we'll get the union + intersection of pixels of that particular class. So, we subtract the already calculated intersection to get the union. Then, we can divide the intersection and union to get the IoU of that particular class and add it to a list of valid IoUs.
At the end, you take the mean of the entire list to get the mIoU. If you want the Dice Coefficient, you can calculate it in a similar fashion.
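For example, a mean Dice coefficient could be sketched along the same lines, reusing the structure (and imports) of the mIOU function above; the only change is that the per-class score is 2 * intersection divided by the sum of the two areas instead of intersection over union:
def mean_dice(label, pred, num_classes=19):
    pred = F.softmax(pred, dim=1)
    pred = torch.argmax(pred, dim=1).view(-1)
    label = label.view(-1)
    dice_list = []
    for sem_class in range(num_classes):
        pred_inds = (pred == sem_class)
        target_inds = (label == sem_class)
        if target_inds.long().sum().item() == 0:
            continue  # class absent from the target: skip it, as in mIOU
        intersection = (pred_inds[target_inds]).long().sum().item()
        total = pred_inds.long().sum().item() + target_inds.long().sum().item()
        dice_list.append(2.0 * intersection / total)
    return np.mean(dice_list)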

Gradient flow through torch.nn.Parameter()

I have a toy example
a = torch.ones(10)
b = torch.nn.Parameter(a, requires_grad=True)
c = (b**2).sum()
c.backward()
print(b.grad)
print(a.grad)
b.grad is calculated successfully, but a.grad is None. How do I make the gradient flow through torch.nn.Parameter? This example looks artificial, but I work with a class A derived from nn.Module whose parameters are initialized with outputs from some other Module B, and I want to make gradients flow through A's parameters to B's parameters.
@a_guest's answer is wrong. Using requires_grad=True here changes nothing, since torch.nn.Parameter is not tracked in the computation graph. You should do it the other way around: create a Parameter tensor first, and then extract a raw tensor reference out of it:
a = torch.nn.Parameter(torch.ones((10,)), requires_grad=True)
b = a[:] # silly hack to convert in a raw tensor including the computation graph
b.retain_grad() # Otherwise backward pass will not store the gradient since it is not a leaf
c = (b**2).sum()
c.backward()
print(b.grad) # tensor of 2s: dc/db = 2*b
print(a.grad) # tensor of 2s: the gradient reached the Parameter through the slice
Another approach would be to manually copy the content of tensor a into b, making the copy explicit:
a = torch.ones((10,), requires_grad=True)
b = torch.nn.Parameter(a.clone(), requires_grad=True)
c = (b**2).sum()
c.backward()
print(b.grad)
print(a.grad) # the Parameter is a new leaf, so gradients do not flow back to a automatically
Yet it is not very convenient since the copy must be done systematically.

How to correctly implement a batch-input LSTM network in PyTorch?

This release of PyTorch seems to provide the PackedSequence for variable-length inputs to recurrent neural networks. However, I found it a bit hard to use correctly.
Using pad_packed_sequence to recover the output of an RNN layer that was fed by pack_padded_sequence, we get a T x B x N tensor outputs, where T is the max time steps, B is the batch size and N is the hidden size. I found that for short sequences in the batch, the subsequent outputs are all zeros.
Here are my questions.
For a single-output task where one needs the last output of each sequence, a simple outputs[-1] will give a wrong result since this tensor contains lots of zeros for the short sequences. One needs to construct indices from the sequence lengths to fetch the individual last output for each sequence. Is there a simpler way to do that?
For a multiple-output task (e.g. seq2seq), usually one adds a linear layer N x O, reshapes the batch outputs T x B x O into TB x O, and computes the cross-entropy loss with the true targets TB (usually integers in a language model). In this situation, do these zeros in the batch output matter?
Question 1 - Last Timestep
This is the code that I use to get the output of the last timestep. I don't know if there is a simpler solution; if there is, I'd like to know it. I followed this discussion and grabbed the relevant code snippet for my last_timestep method. This is my forward:
class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
        ...

    def last_timestep(self, unpacked, lengths):
        # Index of the last output for each sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        return unpacked.gather(1, idx).squeeze()

    def forward(self, x, lengths):
        embs = self.embedding(x)
        # pack the batch
        packed = pack_padded_sequence(embs, list(lengths.data),
                                      batch_first=True)
        out_packed, (h, c) = self.rnn(packed)
        out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)
        # get the outputs from the last *non-masked* timestep for each sentence
        last_outputs = self.last_timestep(out_unpacked, lengths)
        # project to the classes using a linear layer
        logits = self.linear(last_outputs)
        return logits
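(The __init__ above is elided; purely for context, a minimal version consistent with that forward might look like the following, with all sizes being placeholders and an LSTM assumed because the forward unpacks (h, c).)
def __init__(self, vocab_size=10000, emb_dim=100, hidden_size=128, num_classes=5, **kwargs):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, emb_dim)
    self.rnn = nn.LSTM(emb_dim, hidden_size, batch_first=True)
    self.linear = nn.Linear(hidden_size, num_classes)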
Question 2 - Masked Cross Entropy Loss
Yes, by default the zero padded timesteps (targets) matter. However, it is very easy to mask them. You have two options, depending on the version of PyTorch that you use.
PyTorch 0.2.0: Now PyTorch supports masking directly in CrossEntropyLoss, with the ignore_index argument. For example, in language modeling or seq2seq, where I add zero padding, I mask the zero-padded words (targets) simply like this:
loss_function = nn.CrossEntropyLoss(ignore_index=0)
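For instance, a minimal toy version of the reshape-then-mask pattern described in the question (the shapes and the choice of 0 as the padding index are assumptions of this sketch):
import torch
import torch.nn as nn

T, B, O = 5, 3, 10                     # max time steps, batch size, number of classes
outputs = torch.randn(T, B, O)         # T x B x O outputs after the linear layer
targets = torch.randint(1, O, (T, B))  # targets, with class 0 reserved for padding
targets[3:, 1] = 0                     # pretend the second sequence is only 3 steps long

loss_function = nn.CrossEntropyLoss(ignore_index=0)
loss = loss_function(outputs.view(-1, O), targets.view(-1))  # padded steps are ignored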
PyTorch 0.1.12 and older: In the older versions of PyTorch, masking was not supported, so you had to implement your own workaround. The solution that I used was masked_cross_entropy.py, by jihunchoi. You may also be interested in this discussion.
A few days ago, I found this method which uses indexing to accomplish the same task with a one-liner.
I have my dataset batch first ([batch size, sequence length, features]), so for me:
unpacked_out = unpacked_out[np.arange(unpacked_out.shape[0]), lengths - 1, :]
where unpacked_out is the output of torch.nn.utils.rnn.pad_packed_sequence.
I have compared it with the method described here, which looks similar to the last_timestep() method Christos Baziotis is using above (also recommended here), and the results are the same in my case.
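A tiny self-contained check of that indexing trick (toy shapes, batch-first output assumed; torch.arange works just as well as np.arange here):
import torch
# 3 padded sequences, max length 4, hidden size 2
unpacked_out = torch.arange(3 * 4 * 2, dtype=torch.float).view(3, 4, 2)
lengths = torch.tensor([4, 2, 3])
last = unpacked_out[torch.arange(unpacked_out.shape[0]), lengths - 1, :]
print(last.shape)  # torch.Size([3, 2])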

Resources