I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values - workaround needed please - jax

I am new to JAX.
I am implementing a variational autoencoder (VAE) using JAX and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using compositions of flax.linen modules). Crucially, in addition to passing this code through the decoder (as is standard for a VAE), I also pass it to an external function (the MuJoCo physics engine), which tries to assign it to a NumPy array. This unsurprisingly leads to the following error:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object...
Fundamentally, I need to pass a concrete NumPy array to MuJoCo. How can I make my variable a NumPy array while still allowing my model to be implemented in a computationally efficient manner, using abstract tracers wherever possible?
Here is a minimal working example of the problem I am facing - gym and mujoco (https://mujoco.org/) will need to be installed to run this, I believe:
import jax
import jax.numpy as np
import numpy as onp
import gym
from jax import jit

# create an instance of an open AI gym environment
env = gym.make('Humanoid-v3')
env.reset()

def this_fails(env, x):
    # this gives a TracerArrayConversionError
    env.sim.data.qpos[:] = x
    return env, x

x = np.arange(len(env.sim.data.qpos))
jit_this_fails = jax.jit(this_fails, static_argnums=0)
env, x = jit_this_fails(env, x)

Edit: there is now a JAX FAQ entry on this topic: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
Note: this is the answer to the OP's question as originally written. The question has been edited multiple times and no longer asks what it originally asked.
In the past this sort of thing has not been supported, but you can now do it with the new jax.pure_callback feature, which is part of JAX version 0.3.17 (not yet released at the time I am writing this).
For example, say you want to call a numpy-based function from within a JAX jit-compiled function; we'll use np.sin for simplicity. You might first try something like this:
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def this_fails(x):
    # Call a numpy function...
    return np.sin(x)

x = jnp.arange(5.0)
this_fails(x)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function this_fails at tmp.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The result is a TracerArrayConversionError, because you're attempting to pass a traced JAX value into a function that expects a concrete numpy array (side note: see How To Think In JAX for an introduction to JAX Tracers and related topics).
In JAX version 0.3.17 or newer, you can get around this issue using jax.pure_callback:
@jax.jit
def numpy_callback(x):
    # Need to forward-declare the shape & dtype of the expected output.
    result_shape = jax.core.ShapedArray(x.shape, x.dtype)
    return jax.pure_callback(np.sin, result_shape, x)

x = jnp.arange(5.0)
print(numpy_callback(x))
[ 0.         0.841471   0.9092974  0.14112   -0.7568025]
A few caveats to keep in mind:
the resulting execution will rely on a callback to the host, so it will be quite slow on accelerators like GPU/TPU, particularly in distributed/multi-host settings. In the case of local CPU execution, though, it avoids buffer copies and can be quite performant.
if you vmap the function, it will result in a for loop of multiple callbacks; you can specify vectorized=True if the callback function handles batches natively (see the sketch after these caveats).
autodiff transformations like grad and jacobian will not work with this function, because JAX has no way of reasoning about the computations being done. If you would like to use it with autodiff transformations, you could define custom gradients as in Custom Derivative Rules, though this would require having access to a function that computes the gradient for your callback function.
None of this is documented yet on the JAX website, but we hope to write docs for pure_callback soon!
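To make the vmap caveat concrete, here is a minimal sketch (my own, not from the original answer) of passing vectorized=True, assuming a JAX version where pure_callback accepts that keyword. np.sin handles batched inputs natively, so one batched callback replaces a loop of per-element callbacks:
@jax.jit
def batched_callback(x):
    # Forward-declare the per-element output shape & dtype, as before.
    result_shape = jax.core.ShapedArray(x.shape, x.dtype)
    # vectorized=True tells JAX the callback handles batches natively.
    return jax.pure_callback(np.sin, result_shape, x, vectorized=True)

print(jax.vmap(batched_callback)(jnp.arange(6.0).reshape(2, 3)))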

Related

Best way to feed data to tflite interpreter in ARM linux

I have a Python application running on an ARM Linux device. The application continuously acquires images and feeds them to a Tensorflow lite model. This is the relevant portion of the code:
import tflite_runtime.interpreter as tflite
import numpy as np

interpreter = tflite.Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']

while True:
    img = acquire_img()  # 3D numpy array
    img = np.expand_dims(img, 0)
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
I was reading the python API documentation in order to check for possible optimizations. The description of the set_tensor function is:
Sets the value of the input tensor. Note this copies data in value. If
you want to avoid copying, you can use the tensor() function to get a
numpy buffer pointing to the input buffer in the tflite interpreter.
The description of the tensor function is:
Returns function that gives a numpy view of the current tensor buffer.
This allows reading and writing to these tensors w/o copies. This more
closely mirrors the C++ Interpreter class interface's tensor() member,
hence the name. Be careful to not hold these output references through
calls to allocate_tensors() and invoke(). This function cannot be used
to read intermediate results.
Notice how this function avoids making a numpy array directly. This is
because it is important to not hold actual numpy views to the data
longer than necessary. If you do, then the interpreter can no longer
be invoked, because it is possible the interpreter would resize and
invalidate the referenced tensors. The NumPy API doesn't allow any
mutability of the underlying buffers.
Returns A function that can return a new numpy array pointing to the
internal TFLite tensor state at any point. It is safe to hold the
function forever, but it is not safe to hold the numpy array forever.
I am a bit confused: would I gain anything by switching from set_tensor() to tensor()? Is it feasible in my case?
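As an illustration, here is a sketch of the zero-copy pattern the quoted docs describe, under the assumption that acquire_img() (the questioner's own helper) returns an array matching the model's input shape. Note the view from tensor() is recreated each iteration and never held across invoke():
input_index = input_details[0]['index']
get_input = interpreter.tensor(input_index)  # returns a function; safe to hold forever

while True:
    img = acquire_img()          # 3D numpy array, as in the question
    get_input()[0, ...] = img    # fresh view; write in place, then drop it
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])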

How do I make ray.tune.run reproducible?

I'm using the Tune class-based Trainable API. See the code sample:
from ray import tune
import numpy as np
np.random.seed(42)
# first run
tune.run(tune.Trainable, ...)
# second run, expecting same result
np.random.seed(42)
tune.run(tune.Trainable, ...)
The problem is that the tune.run results are still different, the likely reason being that each Ray actor still has a different seed.
Question: how do I make ray.tune.run reproducible?
(This answer focuses on the class API and Ray version 0.8.7. The function API does not support reproducibility due to implementation specifics.)
There are two main sources of nondeterministic results.
1. Search algorithm
Every search algorithm supports a random seed, although the interface to it may vary. This seed initializes the hyperparameter space sampling.
For example, if you're using AxSearch, it looks like this:
from ax.service.ax_client import AxClient
from ray.tune.suggest.ax import AxSearch
client = AxClient(..., random_seed=42)
client.create_experiment(...)
algo = AxSearch(client)
2. Trainable API
Training is distributed among worker processes, so seeding must happen within your tune.Trainable class. Depending on the tune.Trainable.train logic that you implement, you need to manually seed numpy, tf, or whatever other framework you use, inside tune.Trainable.setup, by passing the seed through the config argument of tune.run.
The following example is based on RLlib PR 5197, which handled the same issue:
from ray import tune
import numpy as np
import random
class Tuner(tune.Trainable):
    def setup(self, config):
        seed = config['seed']
        np.random.seed(seed)
        random.seed(seed)
        ...

...
seed = 42
tune.run(Tuner, config={'seed': seed})

What is the difference between parameters and children?

It looks like parameters and children show the same info, so what is the difference between them?
import torch
print('torch.__version__', torch.__version__)
m = torch.load('imagenet_resnet18.pth')
print(m.parameters)
print(m.children)
model.parameters() is a generator that returns tensors containing your model parameters.
model.children() is a generator that returns the layers of the model, from which you can extract your parameter tensors using <layername>.weight or <layername>.bias.
Visit this link for a simple tutorial on accessing and freezing model layers.
The (at the time of my writing, only) other answer is not to the point, and thus misleading, in my opinion. Per the current docs (08/23/2022):
children():
Returns an iterator over immediate children modules.
This means it will stop at non-leaf modules like torch.nn.Sequential, torch.nn.ModuleList, etc., without descending into them.
parameters(recurse=True):
Returns an iterator over module parameters. This is typically passed to an optimizer.
"Passed to an optimizer" should imply that recursive cases are taken care by the team. Just pass the return value/object to the optimizer.
And since I know you're lazy developers, you can see the output of children() in this PyTorch forum answer: https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/4?u=raining_day513
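For the truly lazy, a minimal sketch (my own example, not from the forum post) contrasting the two iterators:
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Sequential(nn.ReLU(), nn.Linear(8, 2)))

# children() stops at the immediate submodules:
for child in model.children():
    print(type(child).__name__)   # Linear, Sequential

# parameters() recurses into every submodule and yields the tensors an optimizer needs:
for p in model.parameters():
    print(tuple(p.shape))         # (8, 4), (8,), (2, 8), (2,)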

How to generate artificial sequential data for machine learning?

Sklearn provides various data generation functions, such as make_blobs and make_regression, in sklearn.datasets.
However, I am not aware of any functions that can generate sequential data. Are there any existing libraries that can generate artificial sequential data?
It really depends on what kind of series you want. Check out TimeSynth, a repository for generating different kinds of simulated series.
But if you just want something you can easily modify yourself, try writing a function similar to this:
import numpy as np

def SynthSeries(start, end, stepSize, coefficients):
    samples = np.arange(start, end, stepSize)
    array = np.zeros(np.shape(samples))
    for coeff in coefficients:
        array = np.add(array, np.sin(coeff * samples))
    return array, samples
This is sort of a reverse Fourier transform: if you know the base frequencies of the series you want to create, you can pass them into this function to recreate the signal.
You can use it like this:
import matplotlib.pyplot as plt
(SeqData, samples) = SynthSeries(0, 20, 0.1, [12, 3, 1, 22])
plt.plot(samples, SeqData)
plt.show()

Custom loss function in PyTorch

I have three simple questions.
1. What will happen if my custom loss function is not differentiable? Will PyTorch throw an error or do something else?
2. If I declare a loss variable in my custom function which represents the final loss of the model, should I set requires_grad = True for that variable? Or does it not matter? If it doesn't matter, then why?
3. I have seen people sometimes write a separate layer and compute the loss in the forward function. Which approach is preferable, writing a function or a layer? Why?
I need a clear and nice explanation of these questions to resolve my confusion. Please help.
Let me have a go.
This depends on what you mean by "non-differentiable". The first definition that makes sense here is that PyTorch doesn't know how to compute gradients. If you try to compute gradients nevertheless, this will raise an error. The two possible scenarios are:
a) You're using a PyTorch operation for which gradients have not been implemented, e.g. torch.svd(). In that case you will get a TypeError:
import torch
from torch.autograd import Function
from torch.autograd import Variable
A = Variable(torch.randn(10,10), requires_grad=True)
u, s, v = torch.svd(A) # raises TypeError
b) You have implemented your own operation, but did not define backward(). In this case, you will get a NotImplementedError:
class my_function(Function):  # forgot to define backward()
    def forward(self, x):
        return 2 * x

A = Variable(torch.randn(10, 10))
B = my_function()(A)
C = torch.sum(B)
C.backward()  # will raise NotImplementedError
The second definition that makes sense is "mathematically non-differentiable". Clearly, an operation which is mathematically not differentiable should either not have a backward() method implemented, or should implement a sensible sub-gradient. Consider for example torch.abs(), whose backward() method returns the subgradient 0 at 0:
A = Variable(torch.Tensor([-1, 0, 1]), requires_grad=True)
B = torch.abs(A)
B.backward(torch.Tensor([1, 1, 1]))
A.grad.data  # tensor([-1., 0., 1.]): the subgradient at 0 is 0
For such cases, you should refer to the PyTorch documentation and dig out the backward() method of the respective operation.
It doesn't matter. The purpose of requires_grad is to avoid unnecessary computation of gradients for subgraphs. If a single input to an operation requires gradient, its output will also require gradient. Conversely, only if all inputs don't require gradient will the output not require it either. Backward computation is never performed in subgraphs where no Variable requires gradients.
Since there are most likely some Variables that require gradients (for example, parameters of a subclass of nn.Module), your loss Variable will also require gradients automatically. However, you should note that, exactly because of how requires_grad works (see above), you can only change requires_grad for leaf variables of your graph anyway.
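To make the propagation rule concrete, a small sketch of my own in modern PyTorch (tensors instead of Variables):
import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)             # requires_grad=False by default
print((a + b).requires_grad)   # True: at least one input requires grad
print((b * 2).requires_grad)   # False: no input requires grad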
All the custom PyTorch loss functions are subclasses of _Loss, which is a subclass of nn.Module. See here. If you'd like to stick to this convention, you should subclass _Loss when defining your custom loss function. Apart from consistency, one advantage is that your subclass will raise an AssertionError if you haven't marked your target variables as volatile or requires_grad = False. Another advantage is that you can nest your loss function in nn.Sequential(), because it's an nn.Module. I would recommend this approach for these reasons.
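As a sketch of the module-based convention (subclassing nn.Module directly rather than the private _Loss, which is my own simplification for newer PyTorch versions):
import torch
import torch.nn as nn

class MyLoss(nn.Module):
    def forward(self, prediction, target):
        # Use only differentiable ops so autograd can derive backward() itself.
        return torch.mean((prediction - target) ** 2)

loss_fn = MyLoss()
pred = torch.randn(8, requires_grad=True)
loss = loss_fn(pred, torch.randn(8))
loss.backward()           # gradients flow back to pred
print(pred.grad.shape)    # torch.Size([8])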
