Extract the diagonal elements of the Hessian in a neural network in JAX

I have a PyTree params (in my case a nested dictionary) containing my parameters of a neural network. My goal is to compute the diagonal entries of the Hessian of a loss function with respect to the parameters and store it in a PyTree of the same structure as the parameters.
When I call jax.hessian(loss_fn)(params, data), I get (as expected) an even more deeply nested dictionary containing the full Hessian.
How can I transform this dictionary to get the desired PyTree with diagonal entries?
To be more concrete: let's say I have only one layer in my network and params is given by
params:
    'linear':
        'w': DeviceArray() of shape [5 x 1]
        'b': DeviceArray() of shape [1]
The returned Hessian has the keys and shape given by
hessian:
    'linear':
        'b':
            'linear':
                'b': (1, 1),
                'w': (1, 5, 1),
        'w':
            'linear':
                'b': (5, 1, 1),
                'w': (5, 1, 5, 1)
As far as I understand it, I need the entries
jnp.diag(hessian['linear']['b']['linear']['b'])
as the diagonal hessian for the bias and
jnp.diag(jnp.squeeze(hessian['linear']['w']['linear']['w']))
as the diagonal hessian for the weights. (However, the squeeze may only work for 1 dim outputs...)
How can I automate this transformation in order to work for more complex models with multiple layers?
I know that this does not scale to huge networks; I only need it for testing optimizers.
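The squeeze concern above can be sidestepped with a reshape. The following is a minimal NumPy sketch, not a full PyTree solution: it assumes each same-parameter Hessian block has shape param_shape + param_shape (as in the listing above), and block_diagonal is a hypothetical helper name. The same reshape should carry over to jnp arrays.

```python
import numpy as np

def block_diagonal(block, param_shape):
    """Diagonal of one Hessian block of shape param_shape + param_shape."""
    n = int(np.prod(param_shape))
    # Collapse the two halves of the block's axes into an (n, n) matrix,
    # take its diagonal, and give it the parameter's own shape back.
    return np.diagonal(block.reshape(n, n)).reshape(param_shape)

# Random stand-in for hessian['linear']['w']['linear']['w'], with w of shape (5, 1).
rng = np.random.default_rng(0)
h_ww = rng.normal(size=(5, 1, 5, 1))
d_w = block_diagonal(h_ww, (5, 1))
print(d_w.shape)  # (5, 1)
```

Because the block is reshaped rather than squeezed, this also covers weight matrices whose output dimension is larger than one.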

I ran into the exact same problem. Unfortunately, working with Pytrees in Jax can be awkward. I was also looking at a way to construct the diagonal Hessian entry-for-entry, since that could yield a practical method.
I now have the following:
from typing import Any, Optional, Sequence
import math

import jax
import jax.numpy as jnp

def ravelled_diagonal_indices(dims: Sequence[int]) -> jnp.ndarray:
    # Get the indices for the diagonal elements of a flattened square matrix.
    return (dims[0] + 1) * jnp.arange(dims[0])

# Alias to reduce clutter.
_diag_idx = ravelled_diagonal_indices

def tree_matrix_diagonal(tree: Any, reference: Optional[Any] = None) -> Any:
    """Utility function for extracting the diagonal of a Pytree of jax.numpy.array objects.

    The Pytree is assumed to be square in its children and in its array objects.

    Parameters
    ----------
    tree : Any
        Pytree of jax.numpy.array objects for which the number of Pytree leaves and
        the sizes of each constituent array is square.
    reference : Any, default = None
        The intended structure for the diagonal of `tree`. For example, this can be
        the Pytree with which `tree` could have been created through e.g., an outer-product
        or the Hessian of a function.

    Returns
    -------
    diag : Any
        Pytree containing the flattened diagonals of `tree` if no reference was provided.
        Otherwise, the diagonal elements are shaped according to the structure of `reference`.
    """
    # The h * h leaves are ordered block-row by block-row, so the
    # same-parameter blocks sit at flat indices 0, h + 1, 2 * (h + 1), ...
    flat = jax.tree_util.tree_leaves(tree)
    h = math.isqrt(len(flat))
    block_diag = [flat[int(i)] for i in _diag_idx((h,))]
    # Within a block, the diagonal is read from the ravelled (n * n,) array.
    flat_diagonal = lambda w: w.ravel()[_diag_idx((math.isqrt(w.size),))]
    diag = jax.tree_util.tree_map(flat_diagonal, block_diag)
    if reference is not None:
        # Reshape the diagonal Pytree to reference Pytree structure and shape
        diag_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(reference), diag)
        diag = jax.tree_util.tree_map(lambda a, b: a.reshape(jnp.shape(b)), diag_tree, reference)
    return diag
When I try this out on the Hessian of a very simple MLP:
params
>> {'dense/~/affine': {'weights': DeviceArray([[ 1. , 1. ],
[ 0.546326 , -0.77997607]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[ 1. ],
[-0.5155028],
[ 0.9487318]], dtype=float32)}}
hessian
>> {'dense/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[-0.02324889, 0.04278728],
[ 0.00814307, -0.01498652]],
[[ 0.04278728, -0.07874574],
[-0.01498652, 0.0275812 ]]],
[[[ 0.00814307, -0.01498652],
[-0.00285216, 0.00524912]],
[[-0.01498652, 0.0275812 ],
[ 0.00524912, -0.00966049]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[ 0.04509945],
[ 0.15897979],
[ 0.05742025]],
[[-0.08300105],
[-0.06711845],
[ 0.01683405]]],
[[[-0.01579637],
[-0.05568369],
[-0.02011181]],
[[ 0.02907166],
[ 0.02350867],
[-0.00589623]]]]], dtype=float32)}}},
'dense_1/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[ 0.04509945, -0.08300105],
[-0.01579637, 0.02907165]]],
[[[ 0.15897979, -0.06711845],
[-0.0556837 , 0.02350867]]],
[[[ 0.05742024, 0.01683406],
[-0.02011181, -0.00589624]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[-0.08748633],
[-0.07074545],
[-0.11138687]]],
[[[-0.07074545],
[-0.05720801],
[-0.09007253]]],
[[[-0.11138687],
[-0.09007251],
[-0.14181684]]]]], dtype=float32)}}}}
Then, the function returns:
tree_matrix_diagonal(hessian, reference=params)
>> {'dense/~/affine': {'weights': DeviceArray([[-0.02324889, -0.07874574],
[-0.00285216, -0.00966049]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[-0.08748633],
[-0.05720801],
[-0.14181684]], dtype=float32)}}
Upon visual inspection, you can see that the returned elements are indeed the diagonal elements of hessian cast to the canonical structure of params.
Funnily enough, for the Gauss-Newton approximation to the Hessian the procedure is much simpler. Simply take the element-wise square of the Jacobians :).
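To illustrate the Gauss-Newton remark, here is a small NumPy sketch. It assumes a squared-error loss 0.5 * ||r||^2 whose residuals have Jacobian J, so that the Gauss-Newton matrix is J^T J; the random J below is a hypothetical stand-in, not taken from the network above.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical residual Jacobian: 6 residuals, 4 parameters.
J = rng.normal(size=(6, 4))

# Full Gauss-Newton matrix for the loss 0.5 * ||r||^2.
gn_full = J.T @ J
# Its diagonal, obtained without forming the full matrix:
# square the Jacobian element-wise and sum over the residual axis.
gn_diag = np.sum(J ** 2, axis=0)
print(np.allclose(np.diag(gn_full), gn_diag))  # True
```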

Related

Interleaving FFT real & complex parts in a PyTorch tensor

I have a use case where I have to compute the FFT of a given tensor. Here, the FFT is applied along each of the 10 rows, which gives the shape (10, 11) after the FFT.
# Random data-
x = torch.rand((10, 20))
# Compute RFFT of 'x'-
x_fft = torch.fft.rfft(x)
# Sanity check-
x.shape, x_fft.shape
# (torch.Size([10, 20]), torch.Size([10, 11]))
# FFT for the first 2 rows are-
x_fft[:2, :]
'''
tensor([[12.2561+0.0000j, 0.7551-1.2075j, 1.1119-0.0458j, -0.2814-1.5266j,
1.4083-0.7302j, 0.6648+0.3311j, 0.3969+0.0632j, -0.8031-0.1904j,
-0.4206+0.9066j, -0.2149+0.9160j, 0.4800+0.0000j],
[ 9.8967+0.0000j, -0.5100-0.2377j, -0.6344+2.2406j, 0.4584-1.0705j,
0.2235+0.4788j, -0.3923+0.8205j, -1.0372-0.0292j, -1.6368+0.5517j,
1.5093+0.0419j, 0.5755-1.2133j, 2.9269+0.0000j]])
'''
# The goal is to have, for each row, a 1-D vector of size 22 as follows.
# For the first row, the desired 1-D vector is:
[12.2561, 0.0000, 0.7551, -1.2075, 1.1119, -0.0458, -0.2814, -1.5266,
1.4083, -0.7302, 0.6648, 0.3311, 0.3969, 0.0632, -0.8031, -0.1904,
-0.4206, 0.9066, -0.2149, 0.9160, 0.4800, 0.0000]
Here, the real and imaginary components are placed adjacent to each other, i.e. interleaved:
[a_1_real, a_1_imag, a_2_real, a_2_imag, a_3_real, a_3_imag, ....., a_n_real, a_n_imag]
Since each row yields 11 complex FFT coefficients, a_n = a_11.
How to go about it?
Your question seems to come down to how to interleave two tensors. Given the two tensors x and y (here the real and imaginary parts, each of shape (2, 11)), you can do so with a combination of stack, transpose and reshape:
>>> torch.stack((x,y),1).transpose(1,2).reshape(2,-1)
tensor([[ 1.1547e+01, 0.0000e+00, 1.3786e+00, -8.1970e-01, -3.2118e-02,
-2.3900e-02, -3.2898e-01, -3.4610e-01, -1.7916e-01, 1.2308e+00,
-5.4203e-01, 1.2580e-01, 8.5273e-01, 8.9980e-01, -2.7096e+00,
-3.8060e-01, 3.0016e-01, -4.5240e-01, -7.7809e-02, 4.5630e-01,
-4.5805e-03, 0.0000e+00],
[ 1.1106e+01, 0.0000e+00, 1.3362e-01, 1.3830e-01, -7.4233e-01,
7.7570e-01, -9.9461e-01, 1.0834e+00, 1.6952e+00, 5.2920e-01,
-1.1884e+00, -2.5970e-01, -8.7958e-01, 4.3180e-01, -9.3039e-01,
8.8130e-01, -1.0048e+00, 1.2823e+00, 2.0595e-01, -6.5170e-01,
1.7209e+00, 0.0000e+00]])
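The same interleaving can be sketched in NumPy for all 10 rows at once. This is an analogue of the PyTorch answer, not the answerer's code; it stacks on a trailing axis so no transpose is needed. In PyTorch, torch.view_as_real(x_fft) should give the same trailing axis of size 2, which can then be reshaped the same way.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((10, 20))
x_fft = np.fft.rfft(x)  # complex array of shape (10, 11)

# Stack real and imaginary parts on a new trailing axis -> (10, 11, 2),
# then merge the last two axes -> (10, 22): [re_1, im_1, re_2, im_2, ...].
interleaved = np.stack((x_fft.real, x_fft.imag), axis=-1).reshape(x_fft.shape[0], -1)
print(interleaved.shape)  # (10, 22)
```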

slice Pytorch tensors which are saved in a list

I have the following code segment to generate random samples. The generated samples are stored in a list, where each entry of the list is a tensor. Each tensor has two elements. I would like to extract the first element from all tensors in the list, and the second element from all tensors in the list as well. How can I perform this kind of tensor slicing?
import torch
import pyro  # needed for pyro.sample below
import pyro.distributions as dist
num_samples = 250
# note that both covariance matrices are diagonal
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])
dist1 = dist.MultivariateNormal(mu1, sig1)
samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]
samples1
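I'd recommend torch.cat with a list comprehension. Note that t[0] is a zero-dimensional tensor, which torch.cat refuses to concatenate, so slice with t[0:1] to keep one dimension:
col1 = torch.cat([t[0:1] for t in samples1])
col2 = torch.cat([t[1:2] for t in samples1])
Docs for torch.cat: https://pytorch.org/docs/stable/generated/torch.cat.html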
ALTERNATIVELY
You could turn your list of 1D tensors into a single big 2D tensor using torch.stack, then do a normal slice:
samples1_t = torch.stack(samples1)
col1 = samples1_t[:, 0] # : means all rows
col2 = samples1_t[:, 1]
Docs for torch.stack: https://pytorch.org/docs/stable/generated/torch.stack.html
I should mention that PyTorch tensors support unpacking out of the box: you can unpack the first axis into multiple variables without additional work. Here torch.stack outputs a tensor of shape (rows, cols); we just need to transpose it to (cols, rows) and unpack:
>>> c1, c2 = torch.stack(samples1).T
So you get c1 and c2 shaped (rows,):
>>> c1
tensor([0.6433, 0.4667, 0.6811, 0.2006, 0.6623, 0.7033])
>>> c2
tensor([0.2963, 0.2335, 0.6803, 0.1575, 0.9420, 0.6963])
Other answers that suggest .stack() or .cat() are perfectly fine from PyTorch perspective.
However, since the context of the question involves pyro, may I add the following:
Since you are doing IID samples
[pyro.sample('samples1', dist1) for _ in range(num_samples)]
A better way to do it with pyro is
dist1 = dist.MultivariateNormal(mu1, sig1).expand([num_samples])
This tells pyro that the distribution is batched with a batch size of num_samples. Sampling from this will produce
>> dist1.sample()
tensor([[-0.8712, 6.6087],
[ 1.6076, -0.2939],
[ 1.4526, 6.1777],
...
[-0.0168, 7.5085],
[-1.6382, 2.1878]])
Now it's easy to solve your original question. Just slice it like
samples = dist1.sample()
samples[:, 0] # all first elements
samples[:, 1] # all second elements

Understanding L2-norm output for 3D tensor - TensorFlow2

For Python 3.8 and TensorFlow 2.5, I have a 3-D tensor of shape (3, 3, 3) where the goal is to compute the L2-norm for each of the three (3, 3) square matrices. The code that I came up with is:
a = tf.random.normal(shape = (3, 3, 3))
a.shape
# TensorShape([3, 3, 3])
a.numpy()
'''
array([[[-0.30071023, 0.9958398 , -0.77897555],
[-1.4251901 , 0.8463568 , -0.6138699 ],
[ 0.23176959, -2.1303613 , 0.01905925]],
[[-1.0487134 , -0.36724553, -1.0881581 ],
[-0.12025198, 0.20973174, -2.1444907 ],
[ 1.4264063 , -1.5857363 , 0.31582597]],
[[ 0.8316077 , -0.7645084 , 1.5271858 ],
[-0.95836663, -1.868056 , -0.04956183],
[-0.16384012, -0.18928945, 1.04647 ]]], dtype=float32)
'''
I am using axis = 2 since the 3rd axis should contain three 3x3 square matrices. The output I get is:
tf.math.reduce_euclidean_norm(input_tensor = a, axis = 2).numpy()
'''
array([[1.299587 , 1.7675754, 2.1430166],
[1.5552354, 2.158075 , 2.15614 ],
[1.8995634, 2.1001325, 1.0759989]], dtype=float32)
'''
How are these values computed? The formula for computing L2-norm is this. What am I missing?
Also, I was expecting three L2-norm values, one for each of the three (3, 3) matrices. The code I have to achieve this is:
tf.math.reduce_euclidean_norm(a[0]).numpy()
# 3.0668826
tf.math.reduce_euclidean_norm(a[1]).numpy()
# 3.4241767
tf.math.reduce_euclidean_norm(a[2]).numpy()
# 3.0293021
Is there a better way to get this without having to explicitly refer to each index of tensor 'a'?
Thanks!
The formula you linked for computing the L2 norm looks correct. What you have is basically this:
np.sqrt(np.sum((a[0]**2)))
# 3.0668826
np.sqrt(np.sum((a[1]**2)))
# 3.4241767
np.sqrt(np.sum((a[2]**2)))
# 3.0293021
This can be vectorized by the following:
np.sqrt(np.sum(a**2, axis=(1,2)))
Output:
array([3.0668826, 3.4241767, 3.0293021], dtype=float32)
Which is effectively the same as using np.linalg.norm (or tf.math.reduce_euclidean_norm if you want to use TensorFlow):
np.linalg.norm(a, ord=None, axis=(1,2))
Output:
array([3.0668826, 3.4241767, 3.0293021], dtype=float32)
With a 2-tuple axis, the default ord=None computes the Frobenius norm of each matrix, which is the L2 norm of its flattened entries, per the documentation. The axis keyword specifies which dimensions to reduce, which should be clear from the first code snippet.
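As for how the axis = 2 values in the question arise: reducing over only the last axis yields one norm per length-3 row vector instead of one per matrix. A small NumPy sketch, using the first (3, 3) matrix copied from the question (a0 is just a local name for a[0]):

```python
import numpy as np

# First (3, 3) matrix of `a`, copied from the question.
a0 = np.array([[-0.30071023, 0.9958398, -0.77897555],
               [-1.4251901, 0.8463568, -0.6138699],
               [0.23176959, -2.1303613, 0.01905925]], dtype=np.float32)

# Reducing over the last axis only (axis=2 on `a`, i.e. axis=-1 here):
# one norm per row vector, e.g. sqrt(0.3007^2 + 0.9958^2 + 0.7790^2).
row_norms = np.sqrt(np.sum(a0 ** 2, axis=-1))
print(row_norms)    # approx. [1.2996 1.7676 2.1430]

# Reducing over both matrix axes gives a single norm for the whole matrix.
matrix_norm = np.sqrt(np.sum(a0 ** 2))
print(matrix_norm)  # approx. 3.0669
```

The row norms match the first row of the axis = 2 output, and the matrix norm matches the value reported for a[0]. In TensorFlow, passing axis = [1, 2] to tf.math.reduce_euclidean_norm should give the three per-matrix values in one call.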

Loop over tensor dimension 0 (NoneType) with second tensor values

I have a tensor a, I'd like to loop over the rows and index values based on another tensor l. i.e. l suggests the length of the vector I need.
sess = tf.InteractiveSession()
a = tf.constant(np.random.rand(3,4)) # shape=(3,4)
a.eval()
Out:
array([[0.35879311, 0.35347166, 0.31525201, 0.24089784],
[0.47296348, 0.96773956, 0.61336239, 0.6093023 ],
[0.42492552, 0.2556728 , 0.86135674, 0.86679779]])
l = tf.constant(np.array([3,2,4])) # shape=(3,)
l.eval()
Out:
array([3, 2, 4])
Expected output:
[array([0.35879311, 0.35347166, 0.31525201]),
array([0.47296348, 0.96773956]),
array([0.42492552, 0.2556728 , 0.86135674, 0.86679779])]
The tricky part is the fact that a could have None as first dimension since it's what is usually defined as batch size through placeholder.
I can not just use mask and condition as below since I need to compute the variance of each row individually.
condition = tf.sequence_mask(l, tf.reduce_max(l))
a_true = tf.boolean_mask(a, condition)
a_true
Out:
array([0.35879311, 0.35347166, 0.31525201, 0.47296348, 0.96773956,
0.42492552, 0.2556728 , 0.86135674, 0.86679779])
I also tried to use tf.map_fn but can't get it to work.
elems = (a, l)
tf.map_fn(lambda x: x[0][:x[1]], elems)
Any help will be highly appreciated!
A TensorArray object can store tensors of different shapes. However, it is still not that simple. Take a look at this example, which does what you want using tf.while_loop() together with tf.TensorArray and tf.slice():
import tensorflow as tf
import numpy as np
batch_data = np.array([[0.35879311, 0.35347166, 0.31525201, 0.24089784],
[0.47296348, 0.96773956, 0.61336239, 0.6093023 ],
[0.42492552, 0.2556728 , 0.86135674, 0.86679779]])
batch_idx = np.array([3, 2, 4]).reshape(-1, 1)
x = tf.placeholder(tf.float32, shape=(None, 4))
idx = tf.placeholder(tf.int32, shape=(None, 1))
n_items = tf.shape(x)[0]
init_ary = tf.TensorArray(dtype=tf.float32,
                          size=n_items,
                          infer_shape=False)
def _first_n(i, ta):
    ta = ta.write(i, tf.slice(input_=x[i],
                              begin=tf.convert_to_tensor([0], tf.int32),
                              size=idx[i]))
    return i+1, ta
_, first_n = tf.while_loop(lambda i, ta: i < n_items,
                           _first_n,
                           [0, init_ary])
first_n = [first_n.read(i)                       # <-- extracts the tensors
           for i in range(batch_data.shape[0])]  #     that you're looking for
with tf.Session() as sess:
    res = sess.run(first_n, feed_dict={x:batch_data, idx:batch_idx})
    print(res)
# [array([0.3587931 , 0.35347167, 0.315252 ], dtype=float32),
# array([0.47296348, 0.9677396 ], dtype=float32),
# array([0.4249255 , 0.2556728 , 0.86135674, 0.8667978 ], dtype=float32)]
Note
We still had to use the batch size (batch_data.shape[0]) to extract elements one by one from the first_n TensorArray using the read() method. We can't use any other method that returns a Tensor because the rows have different sizes (except TensorArray.concat, but it returns all elements concatenated along one dimension).
If the TensorArray has fewer elements than the index you pass to TensorArray.read(index), you will get an InvalidArgumentError.
You can't use tf.map_fn because it returns a tensor, so all elements must have the same shape.
The task is simpler if you only need to compute the variances of the first n elements of each row (without actually gathering elements of different sizes together). In this case we can directly compute the variance of each sliced row, write it to the TensorArray, and then stack the result into a tensor:
n_items = tf.shape(x)[0]
init_ary = tf.TensorArray(dtype=tf.float32,
                          size=n_items,
                          infer_shape=False)
def _variances(i, ta, begin=tf.convert_to_tensor([0], tf.int32)):
    mean, varian = tf.nn.moments(
        tf.slice(input_=x[i], begin=begin, size=idx[i]),
        axes=[0]) # <-- compute variance
    ta = ta.write(i, varian) # <-- write variance of each row to `TensorArray`
    return i+1, ta
_, variances = tf.while_loop(lambda i, ta: i < n_items,
                             _variances,
                             [ 0, init_ary])
variances = variances.stack() # <-- read from `TensorArray` to `Tensor`
with tf.Session() as sess:
    res = sess.run(variances, feed_dict={x:batch_data, idx:batch_idx})
    print(res) # [0.0003761 0.06120085 0.07217039]
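If only the per-row variances are needed, a loop-free alternative is a masked reduction. This is a different technique from the TensorArray approach above, sketched here in NumPy with the data from the question; the same steps could presumably be written in TensorFlow with tf.sequence_mask and tf.reduce_sum, which is an assumption rather than something tested here.

```python
import numpy as np

a = np.array([[0.35879311, 0.35347166, 0.31525201, 0.24089784],
              [0.47296348, 0.96773956, 0.61336239, 0.6093023],
              [0.42492552, 0.2556728, 0.86135674, 0.86679779]])
l = np.array([3, 2, 4])

# mask[i, j] is True for the first l[i] columns of row i.
mask = np.arange(a.shape[1])[None, :] < l[:, None]
# Per-row mean over the unmasked entries only.
mean = np.sum(a * mask, axis=1) / l
# Squared deviations outside the mask are zeroed before summing.
var = np.sum(((a - mean[:, None]) ** 2) * mask, axis=1) / l
print(var)  # approx. [0.0003761 0.06120085 0.07217039]
```

This computes the population variance (dividing by the row length), which is what tf.nn.moments returns, and it works with an unknown batch size because no Python-level indexing over rows is involved.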

Keras 'Error when checking input' when trying to predict multiple values

I have a net with a length 4 input vector, length 2 output vector. I am trying to predict multiple inputs simultaneously. If I just want to predict one, I would do the following and it works:
inp = numpy.array( [ [1,2,3,4] ] )  # `in` is a reserved word in Python, so use another name
self.model.predict(inp)
# prediction = [ [1,2] ]
However, when I try to pass in multiple inputs I get ValueError: Error when checking input: expected dense_1_input to have shape (4,) but got array with shape (1,)
inp = numpy.array( [
    [1,2,3,4],
    [1,2,3,4]
] )
#OR
inp = numpy.array( [
    [ [1,2,3,4] ],
    [ [1,2,3,4] ]
] )
self.model.predict(inp)
#ERR
What am I doing wrong?
Edit:
Code =
model = Sequential()
model.add(Dense(24, input_dim=4, activation='relu'))
model.add(Dense(24, activation='relu'))
model.add(Dense(4, activation='linear'))
model.compile(loss='mse',
optimizer=Adam(lr=self.learning_rate))
print(batch_arr[:,3][0])
predictions = self.model.predict(batch_arr[:,3][0])
print(predictions)
print(batch_arr[:,3])
predictions = model.predict(batch_arr[:,3])
Output =
[[-0.00441936 -0.20398824 -0.08134908 0.09739554]]
[[ 0.01860509 -0.01136071]]
[array([[-0.00441936, -0.20398824, -0.08134908, 0.09739554]])
array([[-0.00517939, 0.38975933, -0.11951023, -0.9718224 ]])
array([[0.00272119, 0.0025476 , 0.002645 , 0.03973542]])
array([[-0.00421809, -0.01006362, -0.07795483, -0.16971247]])
array([[-0.00904593, 0.19332681, -0.10655871, -0.64757587]])
array([[ 0.00654432, 0.00347247, -0.15332555, -0.47302148]])
array([[-0.01921821, -0.17354519, -0.20207744, -0.58569029]])
array([[ 0.00661377, 0.20038962, -0.16278598, -0.80983334]])
array([[-0.00348096, 0.18171964, -0.07072813, -0.38913168]])
array([[-0.01268919, -0.00548544, -0.08286095, -0.27108632]])
array([[ 0.01077598, -0.19254374, -0.004982 , 0.33175341]])
array([[-4.37101750e-04, -5.68196965e-01, -1.99532537e-01,
1.10581883e-01]])
array([[ 0.00657382, -0.19263146, -0.00402872, 0.33368607]])
array([[ 0.00677398, 0.19760551, -0.00076944, -0.25153403]])
array([[ 0.00261579, 0.19642629, -0.13894668, -0.71894379]])
array([[-0.0221003 , 0.37477368, -0.03765055, -0.63564477]])
array([[-0.0110009 , 0.37599703, -0.0574645 , -0.66318148]])
array([[ 0.00277214, 0.19763152, 0.00343971, -0.25211181]])
array([[-9.31810654e-05, -2.06245307e-01, -8.09019674e-02,
1.47356796e-01]])
array([[ 0.00709025, -0.37636771, -0.19725323, -0.11396513]])
array([[ 0.00015344, -0.01233088, -0.07851076, -0.11956039]])
array([[ 0.01077811, -0.18439307, -0.19043179, -0.34107231]])
array([[-0.01460483, 0.18019651, -0.05036345, -0.35505252]])
array([[-0.0127989 , 0.19071515, -0.08828268, -0.58871071]])
array([[ 0.01072609, 0.00249456, -0.00580012, 0.0409061 ]])
array([[ 0.01062156, 0.00782762, -0.17898265, -0.57245695]])
array([[-0.01180104, -0.37085843, -0.1973209 , -0.23782701]])
array([[-0.00849912, -0.00780031, -0.07940117, -0.21980343]])
array([[ 0.00672477, 0.00246062, -0.00160252, 0.04165408]])
array([[-0.02268911, -0.36534914, -0.21379125, -0.36284594]])
array([[-0.00865513, -0.20170279, -0.08379724, 0.0468145 ]])
array([[-0.0256848 , 0.17922475, -0.03098346, -0.33335449]])]
#ERR
Edit: When I print out the shape of batch_arr[:,3] I get (32,), not (32,4) as I expected. So I'm guessing the numpy array does not know the shape of its inner arrays. Is there an easy way to fix that? It might be the root of the problem.
The issue was the way that I had created my numpy array. I created it with indices of variable size, and thus it didn't know it was shaped (32,4), only that it was (32,). Reformulating the logic to ensure that the array is always a set width from the beginning allowed the array to be a (32,4), which allowed the prediction to work.
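The shape mismatch can be reproduced in isolation. The sketch below is a hypothetical reconstruction (the original batch_arr is not shown in full): it builds a 1-D object array holding 32 separate (1, 4) arrays, then concatenates them into a regular (32, 4) batch of the kind model.predict expects.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical reconstruction of the problem: a 1-D object array whose
# 32 entries are each a separate (1, 4) array.
column = np.empty(32, dtype=object)
for i in range(32):
    column[i] = rng.normal(size=(1, 4))
print(column.shape)              # (32,)

# Concatenate the inner arrays along axis 0 to get one numeric batch.
batch = np.concatenate(list(column), axis=0)
print(batch.shape, batch.dtype)  # (32, 4) float64
```

Building the array with a fixed width from the start, as described above, avoids the object array altogether; the concatenation is only a way to convert one after the fact.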
