Callback in JAX fori_loop

Is it possible to have callbacks inside a function passed to JAX fori_loop?
In my case, the callback will save to disk some of the intermediate results produced in the function.
I tried something like this:
def callback(values):
    # do something

def diffusion_loop(i, args):
    # do something
    callback(results)
    return results

final_result, _ = jax.lax.fori_loop(0, num_steps, diffusion_loop, (arg1, arg2))
But then, if I use final_result or whatever was saved from the callback, I get an error like this:
UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1,4,64,64] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scanned_fun at /usr/local/lib/python3.8/dist-packages/jax/_src/lax/control_flow/loops.py:1606 traced for scan.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<timed exec>:81 (<module>)
<timed exec>:67 (diffusion_loop)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:264 (step)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:472 (step_plms)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

It sounds like you want to do a callback to the host that is impure (i.e. it has a side effect of saving values to disk) and does not return any values to the runtime. For that, one option is jax.experimental.host_callback.id_tap, discussed in the host_callback docs.
For example:
import jax
from jax.experimental import host_callback as hcb

def callback(value, transforms):
    # do something
    print(f"callback: {value}")

def diffusion_loop(i, args):
    hcb.id_tap(callback, i)
    return args

args = (1, 2)
result, _ = jax.lax.fori_loop(0, 5, diffusion_loop, args)

Output:
callback: 0
callback: 1
callback: 2
callback: 3
callback: 4
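Applied to the original use case, the same mechanism can do the disk write on the host side, since tapped values arrive there as ordinary NumPy arrays. Here is a minimal sketch; the file naming and the loop body are illustrative placeholders, not from the original code:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def save_to_disk(arg, transforms):
    # Runs on the host, so ordinary NumPy and file I/O are fine here.
    step, value = arg
    np.save(f"step_{int(step)}.npy", np.asarray(value))

def diffusion_loop(i, x):
    x = 0.99 * x  # placeholder for the real diffusion update
    hcb.id_tap(save_to_disk, (i, x))
    return x

final = jax.lax.fori_loop(0, 5, diffusion_loop, jnp.ones((4,)))

Note that host_callback has since been deprecated in newer JAX releases; jax.debug.callback and jax.experimental.io_callback serve the same purpose with a similar shape.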

Related

Shared memory and how to access a global variable from within a class in Python, with multiprocessing?

I am currently developing some code that deals with big multidimensional arrays. Of course, Python gets very slow if you try to perform these computations serially. Therefore, I got into code parallelization, and one of the possible solutions I found involves the multiprocessing library.
What I have come up with so far is first dividing the big array into smaller chunks and then performing some operation on each of those chunks in parallel, using a Pool of workers from multiprocessing. For that to be efficient, and based on this answer, I believe I should use a shared-memory array object defined as a global variable, to avoid copying it every time a process from the pool is called.
Here is a minimal example of what I'm trying to do, to illustrate the issue:
import numpy as np
from functools import partial
import multiprocessing as mp
import ctypes

class Trials:
    # Perform computation along first dimension of shared array, representing the chunks
    def Compute(i, shared_array):
        shared_array[i] = shared_array[i] + 2

    # The function you actually call
    def DoSomething(self):
        # Initializer function for Pool, should define the global variable shared_array
        # I have also tried putting this function outside DoSomething, as a part of the class,
        # with the same results
        def initialize(base, State):
            global shared_array
            shared_array = np.ctypeslib.as_array(base.get_obj()).reshape(125, 100, 100) + State

        base = mp.Array(ctypes.c_float, 125*100*100)  # Create base array
        state = np.random.rand(125, 100, 100)  # Create seed

        # Initialize pool of workers and perform calculations
        with mp.Pool(processes=10,
                     initializer=initialize,
                     initargs=(base, state,)) as pool:
            run = partial(self.Compute,
                          shared_array=shared_array)  # Here the error says that shared_array is not defined
            pool.map(run, np.arange(125))
            pool.close()
            pool.join()
        print(shared_array)

if __name__ == '__main__':
    Trials = Trials()
    Trials.DoSomething()
The trouble I am encountering is that when I define the partial function, I get the following error:
NameError: name 'shared_array' is not defined
From what I understand, that means I cannot access the global variable shared_array. I'm sure the initialize function is executing, as putting a print statement inside it produces output in the terminal.
What am I doing incorrectly, is there any way to solve this issue?
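For context on where the NameError comes from: the initializer runs inside each worker process, so the global it creates exists only there; the parent process, where partial(...) is built, never defines shared_array. A minimal sketch of one common fix, with the worker function at module level reading the global directly (the names follow the question; the return value is illustrative):

import ctypes
import multiprocessing as mp
import numpy as np

def initialize(base, state):
    # Runs once in each worker: defines the global there, not in the parent.
    global shared_array
    shared_array = np.ctypeslib.as_array(base.get_obj()).reshape(125, 100, 100) + state

def compute(i):
    # Workers look up the module-level global set by the initializer,
    # so nothing needs to be passed in through partial from the parent.
    shared_array[i] = shared_array[i] + 2
    return shared_array[i].mean()

if __name__ == '__main__':
    base = mp.Array(ctypes.c_float, 125 * 100 * 100)
    state = np.random.rand(125, 100, 100)
    with mp.Pool(processes=4, initializer=initialize, initargs=(base, state)) as pool:
        results = pool.map(compute, range(125))
    print(results[:3])

One caveat, kept from the question's own code: adding state to the ctypes view produces a new private array in each worker, so the writes do not land in the shared buffer; if they should, keep the shared view and the seed separate.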

Should python unittest mock calls copy passed arguments by reference?

The following code snippets were all run with Python 3.7.11.
I came across some unexpected behavior with the unittest.mock Mock class. I wrote a unit test asserting that a Mock was called with the specific arguments the real object (being mocked) was expected to receive at runtime. The tests passed, so I pushed a build to a real device, only to discover a bug where not all of the arguments were being passed to the real object's method. I quickly found my bug; however, it initially looked to me like my unit test should have failed rather than succeeded. Below are some simplified examples of the situation. What I am curious about is whether this behavior should be considered a bug or an error in understanding on my part.
from unittest.mock import Mock

mock = Mock()
l = [1]
mock(l)
l.append(2)
mock.call_args
# Output: call([1, 2])
# rather than call([1])
id(l) == id(mock.call_args[0][0])
# Output: True
# This means the l object and the latest call_args entry reference the same object in memory
This store-by-reference behavior is confusing because, when a function is called normally, you would not expect the record of that call to reflect items appended to an argument after the call.
def print_var(x):
    print(x)

l = [1]
print_var(l)
# Output: [1]
l.append(2)
# print_var will never be called with [1, 2]
Would it make sense for call_args to use deepcopy to emulate the behavior I was expecting?
I found a solution in the unittest.mock examples:
from unittest.mock import MagicMock
from copy import deepcopy

class CopyingMock(MagicMock):
    def __call__(self, *args, **kwargs):
        args = deepcopy(args)
        kwargs = deepcopy(kwargs)
        return super().__call__(*args, **kwargs)

c = CopyingMock(return_value=None)
arg = set()
c(arg)
arg.add(1)
c.assert_called_with(set())
c.assert_called_with(arg)
# Traceback (most recent call last):
# ...
# AssertionError: Expected call: mock({1})
# Actual call: mock(set())
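If subclassing MagicMock feels heavy, a lighter-weight variation on the same idea from the docs is to snapshot the arguments yourself at call time via side_effect. A sketch, assuming you only need the snapshots for later inspection rather than for assert_called_with:

from copy import deepcopy
from unittest.mock import Mock

snapshots = []
mock = Mock(side_effect=lambda *args, **kwargs: snapshots.append(deepcopy(args)))

l = [1]
mock(l)
l.append(2)

print(snapshots)       # [([1],)] -- copied at the moment of the call
print(mock.call_args)  # call([1, 2]) -- the live reference, mutated afterwards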

Please explain the working of the following code and some questions mentioned below

import keyboard

def fun1(r):
    print(r.name)

keyboard.on_press(fun1)
The code is a simple key logger; what exactly is happening here?
What I understood so far is:
import
function definition
keyboard.on_press is called
Please explain the following things:
What exactly does keyboard.on_press(fun1) pass to fun1()?
Why is having a parameter important for fun1?
What if I don't want to make a function and just want to put my code in keyboard.on_press("here")? Why would that not be possible?
A few more questions:
with keyboard.Listener(
        on_press=on_press,
        on_release=on_release) as listener:
    listener.join()
What's up with the "with" statement here?
What does .join() do? (What does it mean to have it joined to the main thread?)
Why have we written on_press=on_press (why not just once)?
I don't know whether this depends on the version of Python or the version of the module; I am using the latest versions of both.
So far I have read the documentation at
https://pynput.readthedocs.io/en/latest/keyboard.html
and googled all my questions, but could not find an easy explanation.
Here's some code with comments to help explain the syntax:
import keyboard

def somefunc(msg):
    print(msg)  # here

def fun1(r):  # r is the keyboard event
    print(type(r))  # <class 'keyboard._keyboard_event.KeyboardEvent'>
    print(r.name)   # key text
    somefunc("here")  # call some other function

keyboard.on_press(fun1)  # pass a reference to the other function (on_press will call whatever function we pass); that function must take a single parameter
while True: pass  # keep the script running
The with keyword just ensures that the object is closed properly even if there is an error.
with keyboard.Listener(
        on_press=on_press,  # the named parameter is on_press; we are passing a reference to the function on_press
        on_release=on_release) as listener:  # the named parameter is on_release; we are passing a reference to the function on_release
    listener.join()  # wait for the listener thread to finish

# This is shortcut syntax for:
listener = keyboard.Listener(on_press=on_press, on_release=on_release)
try:
    ........
finally:
    listener.stop()  # always do this, even if there was an error
    listener.join()
About the doubled variable name: it looks weird, I know. The code is using named parameters.
# sample function
def myfunc(x):  # the parameter name is x
    print(x)

# We can call myfunc several ways:
myfunc(1)    # pass a value directly
myfunc(x=1)  # specify the parameter to assign
x = 1        # local variable
myfunc(x)    # pass the local variable as the parameter
myfunc(x=x)  # specify the parameter to assign, passing the local variable -- this is the doubled text you're seeing
Concerning the callback function, you can pass a function reference to another function.
def MyCallback(a, b):  # custom function
    print(a, b)

def SomeEventHandler(f):  # the parameter is a reference to another function
    f(1, 2)  # call the passed-in function; it must accept these parameters

SomeEventHandler(MyCallback)  # pass the custom function to the handler
I hope this clears things up a bit.
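One follow-up on question 3: on_press needs a callable, not a statement, but the callable doesn't have to be a named function. A minimal sketch using a lambda (the event parameter still has to be accepted, even if only part of it is used):

import keyboard

# on_press calls whatever callable it is given, passing the key event;
# a lambda is the shortest way to supply one inline.
keyboard.on_press(lambda event: print(event.name))
while True: pass  # keep the script running so the hook stays registered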

How to write a function that sums a list using parallel computing?

I am trying to write a Python function for fast calculation of the sum of a list, using parallel computing. Initially I tried the Python threading library, but then I noticed that all threads run on the same CPU, so there was no speed gain; I therefore switched to multiprocessing. In the first version, I made the list a global variable:
from multiprocessing import Pool

array = 100000000 * [1]

def sumPart(fromTo: tuple):
    return sum(array[fromTo[0]:fromTo[1]])

with Pool(2) as pool:
    print(sum(pool.map(sumPart, [(0, len(array)//2), (len(array)//2, len(array))])))
This worked well and returned the correct sum after about half the time of a serial computation.
But then I wanted to make it a function that accepts the array as an argument:
def parallelSum(theArray):
    def sumPartLocal(fromTo: tuple):
        return sum(theArray[fromTo[0]:fromTo[1]])
    with Pool(2) as pool:
        return sum(pool.map(sumPartLocal, [(0, len(theArray) // 2), (len(theArray) // 2, len(theArray))]))
Here I got an error:
AttributeError: Can't pickle local object 'parallelSum.<locals>.sumPartLocal'
What is the correct way to write this function?
When scheduling jobs to a Python Pool, you need to ensure both the function and its arguments can be serialized, as they will be transferred over a pipe.
Python uses the pickle protocol to serialize its objects. You can see what can be pickled in the module documentation. In your case, you are facing this limitation: only "functions defined at the top level of a module (using def, not lambda)" can be pickled.
Under the hood, the Pool is sending a string with the function name and its parameters. The Python interpreter in the child process looks for that function name in the module and fails to find it, as it's nested in the scope of another function, parallelSum.
Move sumPartLocal outside parallelSum and everything will be fine.
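One way to realize "move it to the top level" while still giving the worker access to theArray is a Pool initializer that stashes the array in a worker-side global. A minimal sketch; the _init and _array names are illustrative, not from the original code:

from multiprocessing import Pool

def _init(arr):
    # Runs once per worker; makes the array visible to the top-level function
    # without turning that function into an unpicklable closure.
    global _array
    _array = arr

def sumPart(fromTo: tuple):
    return sum(_array[fromTo[0]:fromTo[1]])

def parallelSum(theArray):
    bounds = [(0, len(theArray) // 2), (len(theArray) // 2, len(theArray))]
    with Pool(2, initializer=_init, initargs=(theArray,)) as pool:
        return sum(pool.map(sumPart, bounds))

if __name__ == '__main__':
    print(parallelSum(100000000 * [1]))

This still serializes the array once per worker when the pool starts; only the per-task pickling is avoided.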
I believe you are hitting this limitation; see also the pickle documentation.
What you could do is leave def sumPartLocal at module level and pass theArray as the third component of your tuple, so it would be fromTo[2] inside the sumPartLocal function.
Example:
from multiprocessing import Pool

def sumPartLocal(fromTo: tuple):
    return sum(fromTo[2][fromTo[0]:fromTo[1]])

def parallelSum(theArray):
    with Pool(2) as pool:
        return sum(pool.map(sumPartLocal, [
            (0, len(theArray) // 2, theArray),
            (len(theArray) // 2, len(theArray), theArray),
        ]))

if __name__ == '__main__':
    theArray = 100000000 * [1]
    s = parallelSum(theArray)
    print(s)
[EDIT 15-Dec-2017 based on comments]
To anyone who is thinking of multi-threading in Python: I strongly recommend reading up on the Global Interpreter Lock.
There are also some good answers on this question here on SO.
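As a minimal sketch of why the threading attempt showed no speedup (timings are illustrative and vary by machine): the CPU-bound sum holds the GIL, so threads serialize, while processes run truly in parallel:

import time
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

array = 10000000 * [1]
halves = [(0, len(array) // 2), (len(array) // 2, len(array))]

def sumPart(fromTo):
    return sum(array[fromTo[0]:fromTo[1]])

if __name__ == '__main__':
    for pool_cls in (ThreadPool, Pool):
        start = time.perf_counter()
        with pool_cls(2) as pool:
            total = sum(pool.map(sumPart, halves))
        # Expect ThreadPool to take roughly serial time, Pool about half.
        print(pool_cls.__name__, total, f"{time.perf_counter() - start:.2f}s")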

Wrapping all possible method calls of a class in a try/except block

I'm trying to wrap all methods of an existing Class (not of my creation) into a try/except suite. It could be any Class, but I'll use the pandas.DataFrame class here as a practical example.
So if the invoked method succeeds, we simply move on. But if it should generate an exception, it is appended to a list for later inspection/discovery (although the below example just issues a print statement for simplicity).
(Note that the kinds of data-related exceptions that can occur when a method on the instance is invoked aren't yet known; that's the reason for this exercise: discovery.)
This post was quite helpful (particularly @martineau's Python 3 answer), but I'm having trouble adapting it. Below, I expected the second call to the (wrapped) info() method to emit print output but, sadly, it doesn't.
#!/usr/bin/env python3
import functools, sys, types, pandas

def method_wrapper(method):
    @functools.wraps(method)
    def wrapper(*args, **kwargs):  # Note: args[0] points to 'self'.
        try:
            print('Calling: {}.{}()... '.format(args[0].__class__.__name__,
                                                method.__name__))
            return method(*args, **kwargs)
        except Exception:
            print('Exception: %r' % sys.exc_info())  # Something trivial.
            # <Actual code would append that exception info to a list>.
    return wrapper

class MetaClass(type):
    def __new__(mcs, class_name, base_classes, classDict):
        newClassDict = {}
        for attributeName, attribute in classDict.items():
            if type(attribute) == types.FunctionType:   # Replace it with a
                attribute = method_wrapper(attribute)   # decorated version.
            newClassDict[attributeName] = attribute
        return type.__new__(mcs, class_name, base_classes, newClassDict)

class WrappedDataFrame2(MetaClass('WrappedDataFrame',
                                  (pandas.DataFrame, object,), {}),
                        metaclass=type):
    pass

print('Unwrapped pandas.DataFrame().info():')
pandas.DataFrame().info()

print('\n\nWrapped pandas.DataFrame().info():')
WrappedDataFrame2().info()
print()
This outputs:
Unwrapped pandas.DataFrame().info():
<class 'pandas.core.frame.DataFrame'>
Index: 0 entries
Empty DataFrame
Wrapped pandas.DataFrame().info(): <-- Missing print statement after this line.
<class '__main__.WrappedDataFrame2'>
Index: 0 entries
Empty WrappedDataFrame2
In summary,...
>>> unwrapped_object.someMethod(...)
# Should be mirrored by ...
>>> wrapping_object.someMethod(...)
# Including signature, docstring, etc. (i.e. all attributes); except that it
# executes inside a try/except suite (so I can catch exceptions generically).
Long time no see. ;-) In fact it's been such a long time you may no longer care, but in case you (or others) do...
Here's something I think will do what you want. I've never answered your question before now because I don't have pandas installed on my system. However, today I decided to see if there was a workaround for not having it and created a trivial dummy module to mock it (only as far as I needed). Here's the only thing in it:
mockpandas.py:
""" Fake pandas module. """
class DataFrame:
def info(self):
print('pandas.DataFrame.info() called')
raise RuntimeError('Exception raised')
Below is code that seems to do what you need by implementing @Blckknght's suggestion of iterating through the MRO (but ignoring the limitations noted in his answer that could arise from doing it that way). It ain't pretty, but as I said, it seems to work, at least with the mocked pandas library I created.
import functools
import mockpandas as pandas  # mock the library
import sys
import traceback
import types

def method_wrapper(method):
    @functools.wraps(method)
    def wrapper(*args, **kwargs):  # Note: args[0] points to 'self'.
        try:
            print('Calling: {}.{}()... '.format(args[0].__class__.__name__,
                                                method.__name__))
            return method(*args, **kwargs)
        except Exception:
            print('An exception occurred in the wrapped method {}.{}()'.format(
                args[0].__class__.__name__, method.__name__))
            traceback.print_exc(file=sys.stdout)
            # (Actual code would append that exception info to a list)
    return wrapper

class MetaClass(type):
    def __new__(meta, class_name, base_classes, classDict):
        """ See if any of the base classes were created by the with_metaclass() function. """
        marker = None
        for base in base_classes:
            if hasattr(base, '_marker'):
                marker = getattr(base, '_marker')  # remember class name of temp base class
                break  # quit looking
        if class_name == marker:  # temporary base class being created by with_metaclass()?
            return type.__new__(meta, class_name, base_classes, classDict)
        # Temporarily create an unmodified version of the class so its MRO can be used below.
        TempClass = type.__new__(meta, 'TempClass', base_classes, classDict)
        newClassDict = {}
        for cls in TempClass.mro():
            for attributeName, attribute in cls.__dict__.items():
                if isinstance(attribute, types.FunctionType):
                    # Convert it to a decorated version.
                    attribute = method_wrapper(attribute)
                newClassDict[attributeName] = attribute
        return type.__new__(meta, class_name, base_classes, newClassDict)

def with_metaclass(meta, classname, bases):
    """ Create a class with the supplied bases and metaclass, that has been tagged with a
        special '_marker' attribute.
    """
    return type.__new__(meta, classname, bases, {'_marker': classname})

class WrappedDataFrame2(
        with_metaclass(MetaClass, 'WrappedDataFrame', (pandas.DataFrame, object))):
    pass

print('Unwrapped pandas.DataFrame().info():')
try:
    pandas.DataFrame().info()
except RuntimeError:
    print('  RuntimeError exception was raised as expected')

print('\n\nWrapped pandas.DataFrame().info():')
WrappedDataFrame2().info()
Output:
Unwrapped pandas.DataFrame().info():
pandas.DataFrame.info() called
RuntimeError exception was raised as expected
Wrapped pandas.DataFrame().info():
Calling: WrappedDataFrame2.info()...
pandas.DataFrame.info() called
An exception occurred in the wrapped method WrappedDataFrame2.info()
Traceback (most recent call last):
File "test.py", line 16, in wrapper
return method(*args, **kwargs)
File "mockpandas.py", line 9, in info
raise RuntimeError('Exception raised')
RuntimeError: Exception raised
As the above illustrates, the method_wrapper() decorated versions are being used by the methods of the wrapped class.
Your metaclass only applies your decorator to the methods defined in classes that are instances of it. It doesn't decorate inherited methods, since they're not in the classDict.
I'm not sure there's a good way to make it work. You could try iterating through the MRO and wrapping all the inherited methods as well as your own, but I suspect you'd get into trouble if there were multiple levels of inheritance after you start using MetaClass (as each level will decorate the already decorated methods of the previous class).
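A minimal sketch of that double-decoration hazard, using a stripped-down MRO-scanning metaclass (names are illustrative; the calls list records each wrapper layer that fires):

import types

calls = []  # records every wrapper layer that fires

def method_wrapper(method):
    def wrapper(*args, **kwargs):
        calls.append(method.__name__)
        return method(*args, **kwargs)
    return wrapper

class WrapAllMeta(type):
    # Deliberately minimal MRO scan, without the marker trick used above.
    def __new__(meta, name, bases, classDict):
        temp = type.__new__(meta, 'Temp', bases, classDict)  # throwaway class, just for its MRO
        newDict = dict(classDict)
        for cls in temp.mro():
            for attr_name, attr in cls.__dict__.items():
                if isinstance(attr, types.FunctionType):
                    # Wraps plain functions -- including functions that are
                    # already wrappers from a previous level of inheritance.
                    newDict[attr_name] = method_wrapper(attr)
        return type.__new__(meta, name, bases, newDict)

class Base(metaclass=WrapAllMeta):
    def ping(self):
        return 'pong'

class Child(Base):  # the metaclass runs again and re-wraps Base's methods
    pass

Child().ping()
print(calls)  # ['wrapper', 'ping'] -- two layers fired for a single call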
