Struggling to understand nested vmaps in JAX - jax

I just about understand unnested vmaps, but try as I may, and I have tried my darnedest, nested vmaps continue to elude me. Take the snippet from this text for example
I don't understand what the axes are in this case. Is the nested vmap(kernel, (0, None)) some sort of partial function application? Why is the function mapped twice? Can someone please explain, in other words, what is going on behind the scenes? What does a nested vmap desugar to? All the answers that I have found are variants of the same curt explanation, "mapping over both axes", which I am struggling with.

Each time vmap is applied, it maps over a single axis. So say for simplicity that you have a function that takes two scalars and outputs a scalar:
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

print(f(1, 2))
# 3
If you want to apply this function to an array of x values and a single y value, you can do this with vmap:
f_mapped_over_x = jax.vmap(f, in_axes=(0, None))
x = jnp.arange(5)
print(f_mapped_over_x(x, 1))
# [1 2 3 4 5]
in_axes=(0, None) means that it is mapped along the leading axis of the first argument, x, and there is no mapping of the second argument, y.
Likewise, if you want to apply this function to a single x value and an array of y values, you can specify this via in_axes:
f_mapped_over_y = jax.vmap(f, in_axes=(None, 0))
y = jnp.arange(5, 10)
print(f_mapped_over_y(1, y))
# [ 6 7 8 9 10]
If you wish to map the function over both arrays at once, you can do this by specifying in_axes=(0, 0), or equivalently in_axes=0:
f_mapped_over_x_and_y = jax.vmap(f, in_axes=(0, 0))
print(f_mapped_over_x_and_y(x, y))
# [ 5 7 9 11 13]
But suppose you want to map first over x, then over y, to get a sort of "outer-product" version of the function. You can do this via a nested vmap, first mapping over just x, then mapping over just y:
f_mapped_over_x_then_y = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
print(f_mapped_over_x_then_y(x, y))
# [[ 5 6 7 8 9]
# [ 6 7 8 9 10]
# [ 7 8 9 10 11]
# [ 8 9 10 11 12]
# [ 9 10 11 12 13]]
The nesting of vmaps is what lets you map over two axes separately.
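To make the "desugaring" concrete, here is a rough sketch of what the nested vmap computes, written as explicit Python loops. It is only a mental model: f, x and y are redefined so the block runs on its own, and vmap itself does not run a Python loop but traces f once and batches its operations.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

# Inner vmap: x held fixed (None), map over the elements of y (axis 0).
# Outer vmap: map over the elements of x (axis 0), y passed through whole (None).
nested = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

# Loop model: the outer vmap plays the role of the outer loop,
# the inner vmap plays the role of the inner loop.
looped = jnp.stack([
    jnp.stack([f(xi, yj) for yj in y])  # inner: vary y, x fixed
    for xi in x                         # outer: vary x
])

print(bool((nested(x, y) == looped).all()))
```

Each vmap adds one leading axis to the output, so entry [i, j] of the result is f(x[i], y[j]).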

Related

How to merge multiple tuples or lists in to dictionary using loops?

Here is my code to merge all the tuples into a dictionary:
x = (1,2,3)
y = ('car',"truck","plane")
z=("merc","scania","boeing")
products={}
for i in x,y,z:
    products[x[i]]= {y[i]:z[i]}
output:
error:
6 for i in x,y,z:
----> 7 products[x[i]]= {y[i]:z[i]}
8
9 print(products)
TypeError: tuple indices must be integers or slices, not a tuple
Now if I use explicit indices inside the loop to pick positions, as in the code below,
for i in x,y,z:
    products[x[0]]= {y[0]:z[0]}
print(products)
out:
{1: {'car': 'merc'}}
Here I get the structure I need, but only for the one index I specified. How do I create the complete dictionary from multiple lists/tuples?
Is it also possible to use the zip and map functions?
Use zip to iterate over your separate iterables/tuples in parallel:
list(zip(x, y, z)) # [(1, 'car', 'merc'), (2, 'truck', 'scania'), (3, 'plane', 'boeing')]
x = (1, 2, 3)
y = ("car", "truck", "plane")
z = ("merc", "scania", "boeing")
products = {i: {k: v} for i, k, v in zip(x, y, z)}
print(products) # {1: {'car': 'merc'}, 2: {'truck': 'scania'}, 3: {'plane': 'boeing'}}
You should use integers as indices.
x = (1,2,3)
y = ('car',"truck","plane")
z=("merc","scania","boeing")
products={}
for i in range(len(x)):
    products[x[i]]= {y[i]:z[i]}
This should solve your problem.
To add to the above answers, here is a solution using map:
x = (1,2,3)
y = ('car',"truck","plane")
z=("merc","scania","boeing")
products=dict(map(lambda x,y,z:(x,{y:z}),x,y,z))
print(products)

SymPy result Filtering

I was recently working on a Codeforces problem and used SymPy to solve it.
My code is:
from sympy import *
x,y = symbols("x,y", integer = True)
m,n = input().split(" ")
sol = solve([x**2 + y - int(n), y**2 + x - int(m)], [x, y])
print(sol)
What I wanted to do:
Filter only the positive integer solutions from the SymPy result.
Ex: If I enter 14 28 in the terminal it gives me tons of results, but I just want it to show [(5, 3)]
I don't think that this is the intended way to solve the Codeforces problem (I think you're just supposed to loop over the possible values for one of the variables).
I'll show how to make use of SymPy here anyway though. Your problem is a diophantine system of equations. Although SymPy has a diophantine solver it only works for individual equations rather than systems.
Usually the idea of using a CAS for something like this though is to symbolically find something like a general result that then helps you to write faster concrete numerical code. Here are your equations with m and n as arbitrary symbols:
In [62]: x, y, m, n = symbols('x, y, m, n')
In [63]: eqs = [x**2 + y - n, y**2 + x - m]
Using the polynomial resultant we can eliminate either x or y from this system to obtain a quartic polynomial for the remaining variable:
In [31]: py = resultant(eqs[0], eqs[1], x)
In [32]: py
Out[32]: m**2 - 2*m*y**2 - n + y**4 + y
While there is a quartic general formula that SymPy can use (if you use solve or roots here) it is too complicated to be useful for a problem like the one that you are describing. Instead though the rational root theorem tells us that an integer root for y must be a divisor of the constant term:
In [33]: py.coeff(y, 0)
Out[33]: m**2 - n
Therefore the possible values for y are:
In [64]: yvals = divisors(py.coeff(y, 0).subs({m:14, n:28}))
In [65]: yvals
Out[65]: [1, 2, 3, 4, 6, 7, 8, 12, 14, 21, 24, 28, 42, 56, 84, 168]
Since x is m - y**2 the corresponding values for x are:
In [66]: solve(eqs[1], x)
Out[66]: [m - y**2]
In [67]: xvals = [14 - yv**2 for yv in yvals]
In [68]: xvals
Out[68]: [13, 10, 5, -2, -22, -35, -50, -130, -182, -427, -562, -770, -1750, -3122, -7042, -28210]
The candidate solutions are then given by:
In [69]: candidates = [(xv, yv) for xv, yv in zip(xvals, yvals) if xv > 0]
In [70]: candidates
Out[70]: [(13, 1), (10, 2), (5, 3)]
From there you can test which values are solutions:
In [74]: eqsmn = [eq.subs({m:14, n:28}) for eq in eqs]
In [75]: [c for c in candidates if all(eq.subs(zip([x,y],c))==0 for eq in eqsmn)]
Out[75]: [(5, 3)]
The algorithmically minded will probably see from the above example how to implement a much more efficient solver.
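One way to read that remark, as a sketch rather than the answerer's own code: since x = m - y**2, it is enough to loop over candidate y values and check the first equation directly, with no symbolic solving at all. The function name positive_solutions is hypothetical, and it keeps only strictly positive pairs, matching the filtering done above.

```python
from math import isqrt

def positive_solutions(m, n):
    """Positive integer pairs (x, y) with x**2 + y == n and y**2 + x == m."""
    solutions = []
    # y**2 + x == m with x >= 1 forces y**2 <= m - 1, which bounds the loop.
    for y in range(1, isqrt(max(m - 1, 0)) + 1):
        x = m - y * y  # solve the second equation for x
        if x > 0 and x * x + y == n:  # check the first equation
            solutions.append((x, y))
    return solutions

print(positive_solutions(14, 28))  # [(5, 3)]
```

The loop runs about sqrt(m) times, so it needs neither SymPy nor the divisor list.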
I've figured out the answer to my question! At first I was trying to filter the result from solve(), and it turns out there is an easy way to do this.
Pseudo code:
solve() gives the intersection points of the two parabolic equations as a list.
I just need to filter() out the other types of values, which in my case are <sympy.core.add.Add>:
def rem(_list):
    return list(filter(lambda v: type(v) != Add, _list))
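You can also filter the other way round and keep only the integer results, for example by checking v.is_Integer (SymPy returns its own Integer type rather than Python's int, so type(v) == int would not match).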
Final code:
from sympy import *
# the other values were <sympy.core.add.Add> type. So, I just defined a function to filterOUT these specific types from my list.
def rem(_list):
    return list(filter(lambda v: type(v) != Add, _list))
x,y = symbols("x,y", integer = True, negative = False)
output = []
m,n = input().split(' ')
# I need to solve these 2 equations separately. Otherwise, my defined function will not work without loop.
solX = rem(solve((x+(int(n)-x**2)**2 - int(m)), x))
solY = rem(solve((int(m) - y**2)**2 + y - int(n), y))
if len(solX) == 0 or len(solY) == 0:
    print(0)
else:
    output.extend(solX) # using "Extend" to add multiple values in the list.
    output.extend(solY)
    print(int((len(output))/2)) # Obviously, result will come in pairs. So, I need to divide the length of the list by 2.
Why I did it this way:
I tried to solve it algorithmically, but I still got some float numbers, and I wanted to avoid writing the loop again.
Since SymPy's solve() had already found the values, I skipped the other approach and focused on filtering.
Sadly, the Codeforces judge shows a runtime error! I guess it can't import sympy. However, it works fine in VSCode.

Defining a function to calculate mean-differences at specific array size

I have an array:
arr = np.array([1,2,3,4,5,6,7,8])
I want to define a function that calculates the difference of means of the elements of this array for a given group size.
For example:
diff_avg(arr, size=2)
Expected Result:
[-2, -2]
because:
((1+2)/2) - ((3+4)/2) = -2 -> first 4 elements because size is 2, so 2 groups of 2 elements
((5+6)/2) - ((7+8)/2) = -2 -> last 4 elements
if size=3
then:
output: [-3]
because:
((1+2+3)/3) - ((4+5+6)/3) = -3 -> first 6 elements
what I did so far:
def diff_avg(first_group, second_group, size):
    results = []
    x = np.mean(first_group) - np.mean(second_group)
    results.append(x)
    return results
I don't know how to add the size parameter.
I can get the first size elements with arr[:size], but how do I get the next size elements?
Can anyone help me?
First, truncate the array to remove the extra items:
size = 3
sized_array = arr[:arr.size // (size * 2) * (size * 2)]
# array([1, 2, 3, 4, 5, 6])
Next, reshape the sized array into blocks of two groups of size elements each, and take the mean of each group:
means = sized_array.reshape(-1, 2, size).mean(axis=2)
# array([[2., 5.]])
Finally, take the difference within each block:
means[:, 0] - means[:, 1]
# array([-3.])
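The truncate, reshape and subtract steps can be wrapped into the diff_avg(arr, size) signature the question asks for. This is a minimal sketch, assuming the array is cut into consecutive blocks of 2 * size elements and each block is split into two groups of size elements; leftover elements at the end are dropped.

```python
import numpy as np

def diff_avg(arr, size):
    """Difference of means between the two halves of each 2*size block."""
    arr = np.asarray(arr)
    block = 2 * size
    usable = arr.size // block * block  # drop elements that do not fill a block
    # Shape: (number of blocks, 2 groups per block, size elements per group)
    groups = arr[:usable].reshape(-1, 2, size)
    means = groups.mean(axis=2)  # one mean per group
    return means[:, 0] - means[:, 1]  # first group minus second group

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])
print(diff_avg(arr, size=2))  # [-2. -2.]
print(diff_avg(arr, size=3))  # [-3.]
```

Putting the block index on the first axis is what makes both examples from the question come out as expected.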

Finding smallest numbers Python numpy list

I have a Python 3 list which contains an arbitrary number of numpy arrays of varying size/shape. The problem is to set the smallest p% (say p = 20%) of the numbers in the list, in terms of magnitude, to zero.
Example code:
l = []
l.append(np.random.normal(1.5, 1, size = (4, 3)))
l.append(np.random.normal(1, 1, size = (4, 4)))
l.append(np.random.normal(1.8, 2, size = (2, 4)))
for x in l:
    print(x.shape)
'''
(4, 3)
(4, 4)
(2, 4)
'''
How can I remove the smallest p% of numbers from the Python list 'l' "globally"? That is, across all of the numpy arrays contained in 'l', the smallest p% of numbers (in terms of magnitude) should be set to zero.
I am using Python 3.8 and numpy 1.18.
Thanks!
Toy example:
l
'''
[array([[ 0.95400011, 1.95433152, 0.40316605],
[ 1.34477354, 3.24612127, 1.54138912],
[ 1.158594 , 0.77954464, 0.4600395 ],
[-0.03092974, 3.55349303, 0.85526191]]),
array([[ 2.33613547, 0.12361808, 0.27620035, 0.70452795],
[ 0.76989846, -0.28613191, 1.90050011, 2.73843595],
[ 0.13510186, 0.91035556, 1.42402321, 0.60582303],
[-0.13655066, 2.4881577 , 2.0882935 , 1.40347429]]),
array([[-1.63365952, 1.2616223 , 0.86784273, -0.34538727],
[ 1.37161267, 2.4570491 , -0.72419948, 1.91873343]])]
'''
'l' has 36 numbers in it. Now 20% of 36 = 7.2 or rounded down = 7. So the idea is that 7 smallest magnitude numbers out of 36 numbers are removed by masking them to zero!
You can try the following. It finds the threshold value and updates the arrays in the list in place, setting values under the threshold to 0.
Let me know if you need more details.
import numpy as np
l = []
l.append(np.random.normal(1.5, 1, size = (4, 3)))
l.append(np.random.normal(1, 1, size = (4, 4)))
l.append(np.random.normal(1.8, 2, size = (2, 4)))
acc = []
p = 20 #percentile to update to 0
for x in l:
    acc.append(x.flatten())
threshold = np.percentile(np.concatenate(acc),p)
for x in l:
    x[x < threshold] = 0
You can use this:
p = 20 #percentile to remove
lower = np.percentile(np.hstack([x.flatten() for x in l]), p)
for x in l:
    x[x<lower] = 0
You basically stack all the numbers into a single array, use np.percentile to find the threshold for the lowest p%, and then filter the arrays using that threshold.
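Both snippets above threshold on the signed values, so a large negative number would be zeroed before a small positive one. If "smallest in terms of magnitude" is meant literally, a variant that thresholds on absolute values might look like the sketch below. The function name zero_smallest_magnitudes is hypothetical, and rounding the count down follows the toy example in the question (20% of 36 gives 7).

```python
import numpy as np

def zero_smallest_magnitudes(arrays, p):
    """Set the p% smallest-magnitude entries across all arrays to 0, in place."""
    magnitudes = np.concatenate([np.abs(a).ravel() for a in arrays])
    k = int(magnitudes.size * p / 100)  # number of entries to zero, rounded down
    if k == 0:
        return arrays
    # The k-th smallest magnitude is the cut-off (partition avoids a full sort).
    threshold = np.partition(magnitudes, k - 1)[k - 1]
    for a in arrays:
        # Ties at the threshold are all zeroed, so slightly more than k
        # entries can be affected if magnitudes repeat.
        a[np.abs(a) <= threshold] = 0
    return arrays

arrays = [np.array([[1.0, -2.0], [3.0, -0.5]]),
          np.array([5.0, -0.1, 4.0, 7.0, 8.0, 9.0])]
zero_smallest_magnitudes(arrays, p=20)
print(sum(int((a == 0).sum()) for a in arrays))  # 2
```

Here the two entries closest to zero, -0.5 and -0.1, are the ones set to 0, while -2.0 is kept.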

Is it possible to unpack a list of tuples with a list comprehension?

I'd like to unpack tuples from a list of tuples into individual variables using list comprehension. E.g. how to do the second print with a list comprehension instead of an explicit loop:
tuples = [(2, 4), (3, 9), (4, 16)]
# Print in direct order OK
print(('Squares:' + ' {} --> {}' * len(tuples)).format(
    *[v for t in tuples for v in t]))
# Print in reverse order not possible
print('Square roots:', end='')
for t in tuples:
    print (' {} --> {}'.format(t[1], t[0]), end='')
print()
>>> Squares: 2 --> 4 3 --> 9 4 --> 16
>>> Square roots: 4 --> 2 9 --> 3 16 --> 4
Is it possible to replace the second print loop by a list comprehension?
Feel free to simplify further if appropriate.
In python-3.x print is a function, so you can indeed write:
[print (' {} --> {}'.format(*t[::-1]), end='') for t in tuples]
but this is probably worse than using a for loop, since now you allocate memory for every iteration. In case the number of iterations is huge, you will construct a huge list filled with Nones.
It produces:
>>> tuples = [(2, 4), (3, 9), (4, 16)]
>>> [print (' {} --> {}'.format(*t[::-1]), end='') for t in tuples]
4 --> 2 9 --> 3 16 --> 4[None, None, None]
The [None, None, None] is not printed by print; it is simply the value of the list comprehension, which the interactive interpreter echoes.
But that being said, we do not need a list comprehension; we can use ''.join(..) (with a list or generator), like:
print('Squares:'+''.join(' {} --> {}'.format(*t) for t in tuples))
print('Square roots:'+''.join(' {} --> {}'.format(*t[::-1]) for t in tuples))
this produces:
>>> print('Squares:'+''.join(' {} --> {}'.format(*t) for t in tuples))
Squares: 2 --> 4 3 --> 9 4 --> 16
>>> print('Square roots:'+''.join(' {} --> {}'.format(*t[::-1]) for t in tuples))
Square roots: 4 --> 2 9 --> 3 16 --> 4
