Computing Hessian matrices efficiently in JAX

In JAX's Quickstart tutorial I found that the Hessian matrix can be computed efficiently for a differentiable function fun using the following lines of code:
from jax import jit, jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
However, the Hessian can also be computed with any of the following compositions:
def hessian(fun):
    return jit(jacrev(jacfwd(fun)))
def hessian(fun):
    return jit(jacfwd(jacfwd(fun)))
def hessian(fun):
    return jit(jacrev(jacrev(fun)))
Here is a minimal working example:
import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev
def comp_hessian():
    x = jnp.arange(1.0, 4.0)
    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))
    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))
    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))
    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))
    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))
    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))
    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))
    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))
def main():
    comp_hessian()
if __name__ == "__main__":
    main()
I would like to know which approach is best to use and when? I also would like to know if it is possible to use grad() to compute the Hessian? And how does grad() differ from jacfwd and jacrev?

The answer to your question is within the JAX documentation; see for example this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev
To quote its discussion of jacrev and jacfwd:
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
and further down,
To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function wide Jacobian (maybe like a loss function 𝑓:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇𝑓:ℝⁿ→ℝⁿ), which is where forward-mode wins out.
Since your function looks like 𝑓:ℝⁿ→ℝ, then jit(jacfwd(jacrev(fun))) is likely the most efficient approach.
As for why you can't implement a hessian with grad, this is because grad is only designed for derivatives of functions with scalar outputs. A hessian by definition is a composition of vector-valued jacobians, not a composition of scalar gradients.
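As a rough sketch of the points above (assuming a recent JAX install, and reusing the sum_logistics function from the question), the following compares forward-over-reverse with the built-in jax.hessian and with the analytic Hessian. It also shows that grad can serve as the inner step when the function returns a scalar, since the outer Jacobian is what handles the vector-valued gradient:

```python
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev


def sum_logistics(x):
    # Scalar-valued function R^n -> R, as in the question.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


x = jnp.arange(1.0, 4.0)

# Forward-over-reverse, the composition recommended in the quoted docs.
h_fwd_rev = jacfwd(jacrev(sum_logistics))(x)

# grad is usable as the inner step because sum_logistics returns a scalar;
# the outer jacfwd differentiates the vector-valued gradient.
h_fwd_grad = jacfwd(grad(sum_logistics))(x)

# Built-in convenience wrapper.
h_builtin = jax.hessian(sum_logistics)(x)

# Analytic Hessian for comparison: diagonal entries s * (1 - s) * (1 - 2 * s),
# where s is the logistic function evaluated at x.
s = 1.0 / (1.0 + jnp.exp(-x))
h_exact = jnp.diag(s * (1.0 - s) * (1.0 - 2.0 * s))

print(jnp.allclose(h_fwd_rev, h_exact, atol=1e-5))
print(jnp.allclose(h_fwd_grad, h_exact, atol=1e-5))
print(jnp.allclose(h_builtin, h_exact, atol=1e-5))
```

Which composition is fastest depends on the shapes involved, so it is worth timing the candidates on your own function.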

Related

Custom loss function in pytorch 1.10.1

I am struggling to define a custom loss function for pytorch 1.10.1. My model outputs a float ranging from -1 to +1. The target values are floats of arbitrary range. The loss should be the sum of the products of output and target over those samples where the signs of the model output and the target differ.
I have searched the internet for quite some hours, but it seems there have been some changes to pytorch throughout the last versions, so I don't really know which example would best fit to my use case and pytorch 1.10.1.
Here is my approach so far:
class Loss(torch.nn.Module):
    @staticmethod
    def forward(self, output, target) -> Tensor:
        loss = 0.0
        for i in range(len(target)):
            o = output[i,0]
            t = target[i]
            l = o * t
            if l<0: #if different sign
                loss -= l
        return loss
Question:
Should I subclass torch.nn.Module or torch.autograd.Function?
Do I need to define @staticmethod?
On some examples, I saw ctx instead of self being used and invocations of ctx.save_for_backward etc. Do I need this? What is its purpose?
When subclassing torch.nn.Module, my code complains: 'Tensor' object has no attribute 'children'. What am I missing?
When subclassing torch.autograd.Function, my code complains about not having a backward function defined. What should my backward function look like?
Custom loss functions can be as simple as a python function. You can simplify this a bit:
def custom_loss(output, target):
    prod = output[:,0]*target
    return -prod[prod<0].sum()
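To illustrate that a plain function like this is enough for autograd, here is a small sketch that compares it against a loop-based reference and runs backward. The sample tensors and the loop_loss helper are made up for illustration; the shapes (N, 1) for output and (N,) for target are assumed from the indexing in the question.

```python
import torch


def custom_loss(output, target):
    # output: (N, 1) model outputs; target: (N,) targets.
    prod = output[:, 0] * target
    # Keep only products with opposite signs (negative) and flip the sign.
    return -prod[prod < 0].sum()


def loop_loss(output, target):
    # Hypothetical reference implementation following the loop in the question.
    loss = torch.zeros(())
    for i in range(len(target)):
        l = output[i, 0] * target[i]
        if l < 0:
            loss = loss - l
    return loss


# Illustrative data: samples 1 and 2 have mismatched signs.
output = torch.tensor([[0.5], [-0.2], [0.9], [-0.7]], requires_grad=True)
target = torch.tensor([2.0, 3.0, -1.0, -4.0])

loss = custom_loss(output, target)
loss.backward()  # no custom backward needed; autograd tracks the tensor ops

print(loss.item())                      # approximately 1.5
print(loop_loss(output, target).item())  # same value from the loop version
print(output.grad)                      # nonzero only for mismatched samples
```

Because every step is an ordinary tensor operation, neither torch.nn.Module nor torch.autograd.Function is required here.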

matrix multiplication for complex numbers in PyTorch

I am trying to multiply two complex matrices in PyTorch, and it seems that torch.matmul has not yet been implemented for complex numbers in the PyTorch library.
Do you have any recommendation or is there another method to multiply complex matrices in PyTorch?
Currently torch.matmul is not supported for complex tensors such as ComplexFloatTensor but you could do something as compact as the following code:
def matmul_complex(t1, t2):
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc), applied to whole matrices
    real = t1.real @ t2.real - t1.imag @ t2.imag
    imag = t1.real @ t2.imag + t1.imag @ t2.real
    return torch.view_as_complex(torch.stack((real, imag), dim=2))
When possible avoid using for loops as these will result in much slower implementations.
Vectorization is achieved by using built-in methods as demonstrated in the code I have attached.
For example, your code takes roughly 6.1s on CPU while the vectorized version takes only 101ms (~60 times faster) for 2 random complex matrices with dimensions 1000 X 1000.
Update:
Since PyTorch 1.7.0 (as @EduardoReis mentioned) you can do matrix multiplication between complex matrices similarly to real-valued matrices as follows:
t1 @ t2
(for t1, t2 complex matrices).
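As a quick sanity check, the sketch below compares the real/imaginary decomposition with the native operator. It assumes a PyTorch version with complex matmul support (1.7.0 or later, per the update above); the matrix sizes and seed are arbitrary.

```python
import torch


def matmul_complex(t1, t2):
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc), applied to whole matrices.
    real = t1.real @ t2.real - t1.imag @ t2.imag
    imag = t1.real @ t2.imag + t1.imag @ t2.real
    return torch.view_as_complex(torch.stack((real, imag), dim=-1))


torch.manual_seed(0)
t1 = torch.randn(4, 3, dtype=torch.cfloat)
t2 = torch.randn(3, 5, dtype=torch.cfloat)

manual = matmul_complex(t1, t2)
native = t1 @ t2

# Compare through the real view so the check only involves real tensors.
same = torch.allclose(
    torch.view_as_real(manual), torch.view_as_real(native), atol=1e-5
)
print(manual.shape, same)
```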
I implemented matrix multiplication for complex numbers using torch.mv, and it works fine for the time being:
def matmul_complex(t1, t2):
    m = list(t1.size())[0]
    n = list(t2.size())[1]
    t_total = torch.empty((m,n), dtype=torch.cfloat)
    for i in range(0,n):
        # Each torch.mv call yields one column of the product
        if i == 0:
            t_total = torch.mv(t1,t2[:,i])
        else:
            t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
    # t_total holds the n result columns back to back, so reshape to (n, m)
    # and transpose to obtain the (m, n) product
    t_final = torch.reshape(t_total, (n,m)).T
    return t_final
I am new to PyTorch, so please correct me if I am wrong.

How to evaluate loss only on elements satisfying a condition pytorch

I'm working on a regression problem in pytorch. I get good results on my evaluation set, but I want to make sure it's not because I have many small elements and fewer large ones. Therefore, I would like to check whether I get a similar loss for the large elements (e.g. elements > 0.01). I use MSE loss.
Can anyone please suggest a way of doing so?
Thanks!
You can zero out the loss for smaller elements (assuming the size of an element is judged by your regression target) by implementing your own loss function like this:
import torch
class CustomMSE:
    def __init__(self, threshold=0.01, reduction=torch.mean):
        self.threshold = threshold
        self.reduction = reduction
    def __call__(self, output, target):
        # Do not reduce, so you get per-element loss
        loss = torch.nn.functional.mse_loss(output, target, reduction="none")
        # Zero the loss wherever the target is below the threshold
        loss[target < self.threshold] = 0
        return self.reduction(loss)
criterion = CustomMSE()
You can use it just like torch.nn.MSELoss. This should give you an overall idea.
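One thing to keep in mind with zeroing is that torch.mean still divides by the total number of elements, so the zeroed entries dilute the average. A possible alternative, sketched here with a hypothetical masked_mse helper and made-up data, is to select the large elements with a boolean mask and average over those only:

```python
import torch
import torch.nn.functional as F


def masked_mse(output, target, threshold=0.01):
    # Hypothetical helper: MSE over elements whose target exceeds the threshold.
    mask = target > threshold
    # Note: if no element passes the threshold, the mean of an empty
    # selection is NaN, so callers may want to guard against that case.
    return F.mse_loss(output[mask], target[mask])


# Illustrative data: two small targets and two large ones.
target = torch.tensor([0.001, 0.5, 0.002, 1.0])
output = torch.tensor([0.0, 0.4, 0.1, 1.2])

large_only = masked_mse(output, target)

# Zeroing approach for comparison: same numerator, but divided by all 4 elements.
per_element = F.mse_loss(output, target, reduction="none")
per_element[target < 0.01] = 0
zeroed_mean = per_element.mean()

print(large_only.item())   # approximately 0.025
print(zeroed_mean.item())  # approximately 0.0125
```

Which of the two numbers is more useful depends on whether you want the loss per large element or the contribution of large elements to the overall mean.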

What's the difference between sum and torch.sum for a torch Tensor?

I get the same results when using either the python sum or torch.sum so why did torch implement a sum function? Is there a difference between them?
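For a 1-D tensor there is no difference in the result. torch.sum calls tensor.sum, while Python's sum iterates over the tensor and adds the elements one at a time with __add__ (or __radd__ when needed).
So the difference is in the number of function calls, and tensor.sum() should be the fastest (when you have small tensors and the function call overhead is considerable). Note that for a tensor with more than one dimension, Python's sum only adds along the first dimension, whereas torch.sum reduces over all elements by default.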
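A small sketch of how the two behave on 1-D and 2-D inputs (the tensors here are arbitrary examples). Python's sum iterates over the first dimension, so on a 2-D tensor it adds the rows together rather than reducing to a single number:

```python
import torch

v = torch.arange(6.0)   # tensor([0., 1., 2., 3., 4., 5.])
m = v.reshape(2, 3)     # two rows: [0, 1, 2] and [3, 4, 5]

# 1-D: both give the same value.
py_sum_1d = sum(v)
torch_sum_1d = torch.sum(v)
print(py_sum_1d, torch_sum_1d)   # tensor(15.) tensor(15.)

# 2-D: Python's sum adds the rows, torch.sum reduces over every element.
py_sum_2d = sum(m)
torch_sum_2d = torch.sum(m)
print(py_sum_2d)      # tensor([3., 5., 7.])
print(torch_sum_2d)   # tensor(15.)
```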
It appears python's sum can take generators as input, whereas torch.sum cannot:
import torch
print( sum( torch.ones(1)*k for k in torch.arange(10)))
returns tensor([45.]), whereas:
print( torch.sum( torch.ones(1)*k for k in torch.arange(10)))
raises TypeError: sum(): argument 'input' (position 1) must be Tensor, not generator
I assume PyTorch's backpropagation would have trouble with the lazy evaluation of a generator, but I'm not sure about that yet.

pymc3 theano function usage

I'm trying to define a complex custom likelihood function using pymc3. The likelihood function involves a lot of iteration, and therefore I'm trying to use theano's scan method to define iteration directly within theano. Here's a greatly simplified example that illustrates the challenge that I'm facing. The (fake) likelihood function I'm trying to define is simply the sum of two pymc3 random variables, p and theta. Of course, I could simply return p+theta, but the actual likelihood function I'm trying to write is more complicated, and I believe I need to use theano.scan since it involves a lot of iteration.
import pymc3 as pm
from pymc3 import Model, Uniform, DensityDist
import theano.tensor as T
import theano
import numpy as np
### theano test
theano.config.compute_test_value = 'raise'
X = np.asarray([[1.0,2.0,3.0],[1.0,2.0,3.0]])
### pymc3 implementation
with Model() as bg_model:
    p = pm.Uniform('p', lower = 0, upper = 1)
    theta = pm.Uniform('theta', lower = 0, upper = .2)
    def logp(X):
        f = p+theta
        print("f",f)
        get_ll = theano.function(name='get_ll',inputs = [p, theta], outputs = f)
        print("p keys ",p.__dict__.keys())
        print("theta keys ",theta.__dict__.keys())
        print("p name ",p.name,"p.type ",p.type,"type(p)",type(p),"p.tag",p.tag)
        result=get_ll(p, theta)
        print("result",result)
        return result
    y = pm.DensityDist('y', logp, observed = X) # Nx4 y = f(f,x,tx,n | p, theta)
When I run this, I get the error:
TypeError: ('Bad input argument to theano function with name "get_ll" at index 0(0-based)', 'Expected an array-like object, but found a Variable: maybe you are trying to call a function on a (possibly shared) variable instead of a numeric array?')
I understand that the issue occurs in line
result=get_ll(p, theta)
because p and theta are of type pymc3.TransformedRV, and the input to a theano function needs to be a scalar number or a simple numpy array. However, a pymc3 TransformedRV does not seem to have any obvious way of obtaining the current value of the random variable itself.
Is it possible to define a log likelihood function that involves the use of a theano function that takes as input a pymc3 random variable?
The problem is that your th.function get_ll is a compiled theano function, which takes as input numerical arrays. Instead, pymc3 is sending it a symbolic variable (theano tensor). That's why you're getting the error.
As to your solution, you're right in saying that just returning p+theta is the way to go. If you have scans and whatnot in your logp, then you would return the scan variable of interest; there is no need to compile a theano function here. For example, if you wanted to add 1 to each element of a vector (as an impractical toy example), you would do:
def logp(X):
    # th is the theano module (import theano as th)
    the_sum, the_sum_upd = th.scan(lambda x: x+1, sequences=[X])
    return the_sum
That being said, if you need gradients, you would need to calculate your the_sum variable in a theano Op and provide a grad() method along with it (you can see a toy example of that on the answer here). If you do not need gradients, you might be better off doing everything in python (or C, numba, cython, for performance) and using the as_op decorator.

Resources