I am trying to JIT-compile a Python function with Numba and use an optional argument to change the arguments passed to another function call.
I think where jit might be tripping up is that the default value of the optional argument is None, and jit doesn't know how to handle that, or at least doesn't know how to handle it when it changes to a numpy array. See below for a rough overview:
#jit(nopython=True)
def foo(otherFunc, arg1, optionalArg=None):
    """Call ``otherFunc`` on ``arg1``, forwarding ``optionalArg`` only when given.

    When ``optionalArg`` is None the call is ``otherFunc(arg1)``; otherwise it
    is ``otherFunc(arg1, optionalArg)``. The result of that call is returned.
    """
    if optionalArg is None:
        return otherFunc(arg1)
    return otherFunc(arg1, optionalArg)
Where optionalArg is either None, or a numpy array
One solution would be to turn this into three functions as shown below, but this feels kinda janky and I don't like it, especially because speed is very important for this task.
def foo(otherFunc, arg1, optionalArg=None):
    """Plain-Python dispatcher choosing a helper by whether ``optionalArg`` is given.

    Each helper below has one fixed call pattern, so it never has to handle
    the ``None`` default itself.
    """
    if optionalArg is None:
        return func2(otherFunc, arg1)
    return func1(otherFunc, arg1, optionalArg)

#jit(nopython=True)
def func1(otherFunc, arg1, optionalArg):
    """Return ``otherFunc(arg1, optionalArg)``."""
    return otherFunc(arg1, optionalArg)

#jit(nopython=True)
def func2(otherFunc, arg1):
    """Return ``otherFunc(arg1)``."""
    return otherFunc(arg1)
Note that other stuff is happening besides just calling otherFunc that makes using jit worth it, but I'm almost certain that is not where the problem is since this was working before without the optionalArg portion, so I have decided not to include it.
For those of you who are curious, it's a fourth-order Runge–Kutta (RK4) implementation with optional extra parameters to pass to the differential equation. If you want to see the whole thing, just ask.
The traceback is rather long but here is some of it:
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):
File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:
This continues...
inte.rk4 is the equivalent of foo, de2 is otherFunc, y0, 0.001 and 200 are just values that I swapped out for arg1 in my problem description above, and vals is optionalArg.
A similar thing happens when I try to run this with the vals parameter omitted:
ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):
File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
ysExp=inte.rk4(deExp,y0,0.001,200)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:
This continues...
If you see the documentation here, you can specify the optional type arguments explicitly in Numba. For example (this is the same example from documentation):
>>> @jit((optional(intp),))
... def f(x):
...     return x is not None
...
>>> f(0)
True
>>> f(None)
False
Additionally, based on the conversation going on this Github issue you can use the following workaround to implement optional keyword. I have modified the code from the solution provided in the github issue to suit your example:
from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np
np_arr = np.asarray([1,2])

# jitclass needs an explicit numba type for every field; x is a 32-bit int.
spec = OrderedDict()
spec['x'] = int32


# NOTE(review): recent numba releases expose jitclass as
# numba.experimental.jitclass; `from numba import jitclass` may fail there.
@jitclass(spec)
class Foo(object):
    """Minimal jitclass whose method accepts either None or an array."""

    def __init__(self, x):
        self.x = x

    def otherFunc(self, optionalArg):
        # Returns x + 10 when no array is passed, otherwise the array length;
        # the call sites below exercise both the None and the array case.
        if optionalArg is None:
            return self.x + 10
        else:
            return len(optionalArg)


@njit
def useOtherFunc(arg1, optArg):
    # Build the jitclass instance inside nopython code and forward optArg
    # unchanged (a NumPy array or None).
    foo = Foo(arg1)
    print(foo.otherFunc(optArg))


arg1 = 5
useOtherFunc(arg1, np_arr)  # Output: 2
useOtherFunc(arg1, None)  # Output : 15
See this colab notebook for the example shown above.
Related
Context: I want to create attributes of an object class in parallel by distributing them in the available cores. This question was answered in this post here by using the python Multiprocessing Pool.
The MRE for my task is the following using Pyomo 6.4.1v:
from pyomo.environ import *
import os
import multiprocessing
from multiprocessing import Pool
from multiprocessing.managers import BaseManager, NamespaceProxy
import types
class ObjProxy(NamespaceProxy):
    """Manager proxy usable for any user-defined class.

    Public attribute reads are answered by the referent living in the manager
    process (via ``NamespaceProxy``); private/protected names are not
    forwarded. When the fetched attribute is a bound method, a local callable
    is handed back instead, and calling it runs the method remotely through
    ``_callmethod`` so that state changes land on the shared object. Proxies
    are picklable, so the shared state can be used from several processes.

    NOTE(review): to decide whether to wrap, the attribute is fetched from
    the manager first. For a method that means a bound method is pickled
    back, and pickling a bound method also pickles the object it is bound to
    (here presumably the whole referent), which may be expensive for large
    objects -- verify for big models.
    """

    def __getattr__(self, name):
        value = super().__getattr__(name)
        if not isinstance(value, types.MethodType):
            return value

        def remote_call(*args, **kwargs):
            # Execute inside the manager process, against the shared referent.
            return self._callmethod(name, args, kwargs)

        return remote_call
@classmethod
def create(cls, *args, **kwargs):
    """Instantiate ``cls`` inside a new manager process and return a proxy to it.

    Defined as a classmethod so it can be attached to a class (see the
    assignment below) and called as ``SomeClass.create(*args, **kwargs)``.

    Each call registers ``cls`` on ``BaseManager`` itself under its class
    name, exposing every name in ``dir(cls)`` and using ``ObjProxy`` as the
    proxy type. Because the registration is on ``BaseManager`` itself, the
    typeid is visible to every BaseManager in this process. Each call also
    starts a new manager process. The manager is not shut down here; the
    returned proxy depends on it.
    """
    class_str = cls.__name__
    BaseManager.register(class_str, cls, ObjProxy, exposed=tuple(dir(cls)))
    # Start a manager process
    manager = BaseManager()
    manager.start()
    # register() added a factory method named after the typeid to the manager.
    # Look it up directly instead of building source text for eval(). The
    # returned proxy shares the referent's state between processes.
    return getattr(manager, class_str)(*args, **kwargs)


ConcreteModel.create = create
class A:
    """Builds a Pyomo model through a manager proxy and adds components to it.

    ``init_var`` first adds attributes sequentially, then tries to add a batch
    of ``Set`` components in parallel with a ``multiprocessing.Pool``.
    """
    def __init__(self):
        # ConcreteModel.create() returns an ObjProxy to a model living in a
        # manager process, not a local ConcreteModel.
        self.model = ConcreteModel.create()

    def do_something(self, var):
        """Store the string ``var`` as model.var1/model.var2 when it names one; else print a notice."""
        # Public attribute assignment on the proxy is forwarded to the manager.
        if var == 'var1':
            self.model.var1 = var
        elif var == 'var2':
            self.model.var2 = var
        else:
            print('other var.')

    def do_something2(self, model, var_name, var_init):
        """Add component ``var_init`` to ``model`` under the name ``var_name``."""
        # With a proxied model this is a remote call: ObjProxy forwards it via
        # _callmethod, so var_init is pickled and sent to the manager process.
        # NOTE(review): the attribute lookup of add_component first fetches
        # the bound method from the manager, which pickles the object it is
        # bound to (presumably the whole model) -- verify.
        model.add_component(var_name,var_init)

    def init_var(self):
        """Add components sequentially, then attempt to add several Sets in parallel."""
        print('Sequentially')
        self.do_something('var1')
        self.do_something('test')
        print(self.model.var1)
        # vars() on the proxy lists the proxy's own attributes, not the model's.
        print(vars(self.model).keys())

        # Trying to create the attributes in parallel
        print('\nParallel')
        # NOTE: the double leading underscore name-mangles this attribute to
        # self._A__sets_list.
        # NOTE(review): 'gp4' is listed twice for gran_plants, which matches
        # the "Element gp4 already exists" warning in the output -- likely a typo.
        self.__sets_list = [(self.model,'time',Set(initialize = [x for x in range(1,13)])),
                            (self.model,'customers',Set(initialize = ['c1','c2','c3'])),
                            (self.model,'finish_bulks',Set(initialize = ['b1','b2','b3','b4'])),
                            (self.model,'fermentation_types',Set(initialize = ['ft1','ft2','ft3','ft4'])),
                            (self.model,'fermenters',Set(initialize = ['f1','f2','f3'])),
                            (self.model,'ferm_plants',Set(initialize = ['fp1','fp2','fp3','fp4'])),
                            (self.model,'plants',Set(initialize = ['p1','p2','p3','p4','p5'])),
                            (self.model,'gran_plants',Set(initialize = ['gp1','gp2','gp3','gp4','gp4']))]
        # NOTE(review): starmap pickles the bound method self.do_something2,
        # and therefore self (including self._A__sets_list, i.e. every Set),
        # for each task batch. Every worker then ships its Set to the single
        # manager process, which presumably serves the workers concurrently
        # while mutating one model. Pyomo components do not appear designed
        # for this, which is a plausible source of the intermittent pickling
        # errors (e.g. "dictionary changed size during iteration"). Consider
        # building these small Sets sequentially instead -- verify.
        with Pool(7) as pool:
            pool.starmap(self.do_something2,self.__sets_list)
        self.model.time.pprint()
        self.model.customers.pprint()
def main():  # The main part run from another file
    """Create an ``A`` instance and run its attribute initialisation."""
    obj = A()
    obj.init_var()
    # Call other methods to create other attributes and the solver step.
    # The other methods are similar to do_something2() just changing the var_init to Var() and Constraint().


if __name__ == '__main__':
    multiprocessing.set_start_method("spawn")
    # main() returns None; call it for its side effects instead of rebinding
    # the module-level name `main` to its result.
    main()
Ouput
Sequentially
other var.
var1
dict_keys(['_tls', '_idset', '_token', '_id', '_manager', '_serializer', '_Client', '_owned_by_manager', '_authkey', '_close'])
Parallel
WARNING: Element gp4 already exists in Set gran_plants; no action taken
time : Size=1, Index=None, Ordered=Insertion
Key : Dimen : Domain : Size : Members
None : 1 : Any : 12 : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
customers : Size=1, Index=None, Ordered=Insertion
Key : Dimen : Domain : Size : Members
None : 1 : Any : 3 : {'c1', 'c2', 'c3'}
I change the number of parallel processes for testing; sometimes it raises different errors, and other times it runs without errors. This is confusing for me, and I have not figured out the problem behind it. I did not find another post with a similar problem, but I saw some posts saying that pickle does not handle large data. These are the errors that I sometimes get:
Error 1
Unserializable message: Traceback (most recent call last):
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/managers.py", line 300, in serve_client
send(msg)
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/connection.py", line 211, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
SystemError: <method 'dump' of '_pickle.Pickler' objects> returned NULL without setting an error
Error 2
Unserializable message: Traceback (most recent call last):
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/managers.py", line 300, in serve_client
send(msg)
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/connection.py", line 211, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
RuntimeError: dictionary changed size during iteration
Error 3
*** Reference count error detected: an attempt was made to deallocate the type 32727 (? ***
*** Reference count error detected: an attempt was made to deallocate the type 32727 (? ***
*** Reference count error detected: an attempt was made to deallocate the type 32727 (? ***
Unserializable message: Traceback (most recent call last):
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/managers.py", line 300, in serve_client
send(msg)
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/connection.py", line 211, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
numpy.core._exceptions._ArrayMemoryError: <unprintble MemoryError object>
Error 4
Unserializable message: Traceback (most recent call last):
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/managers.py", line 300, in serve_client
send(msg)
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/connection.py", line 211, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/.../anaconda3/envs/.../lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'WeakSet.__init__.<locals>._remove'
So, there are different errors, and it looks like it is not stable. I hope that someone has had and solved this problem. Furthermore, if someone has implemented other strategies for this task, please, feel free to post your answer in this issue here
Tkx.
I'm currently working on a domain adaptation project that uses a gradient reversal layer (GradReverse).
This is what I'm using.
class GradReverse(Function):
    """Gradient reversal: identity in forward, gradient scaled by ``-alpha`` in backward.

    NOTE: ``forward`` and ``backward`` are defined here without
    ``@staticmethod``. The follow-up in this thread attributes the
    ``NotImplementedError: You must implement the backward function for
    custom autograd.Function`` to that omission. Custom autograd Functions
    are expected to declare both as static methods and to be invoked as
    ``GradReverse.apply(x, alpha)``.
    """
    def forward(ctx, x, alpha):
        # Keep alpha for the backward pass.
        ctx.alpha = alpha
        return x

    def backward(ctx, grad_output):
        # Reverse and scale the incoming gradient; alpha itself gets no gradient.
        output = grad_output.neg() * ctx.alpha
        return output, None
It seems there's nothing wrong with it.
However, when I run my code, it raised this error.
Traceback (most recent call last):
File "main.py", line 125, in <module>
class_criterion, domain_criterion, optimizer, trainloader1, valloader, trainloader2, testloader2)
File "/content/drive/MyDrive/TRANSFER/Train.py", line 325, in train_original
src_domain_loss.backward(retain_graph=True)
File "/usr/local/lib/python3.7/dist-packages/torch/tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 147, in backward
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
File "/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py", line 89, in apply
return self._forward_cls.backward(self, *args) # type: ignore
File "/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py", line 201, in backward
raise NotImplementedError("You must implement the backward function for custom"
NotImplementedError: You must implement the backward function for custom autograd.Function.
I can't figure it out. So I add
src_domain_loss=torch.tensor(src_domain_loss, requires_grad=True)
And there's no warning.
However, when I check my network's gradients, they look like this:
-->name: domain_classifier.c_fc1.weight -->grad_requirs: True -->grad_value: None
-->name: domain_classifier.c_fc1.bias -->grad_requirs: True -->grad_value: None
-->name: domain_classifier.c_fc2.weight -->grad_requirs: True -->grad_value: None
-->name: domain_classifier.c_fc2.bias -->grad_requirs: True -->grad_value: None
And once I remove the gradreverselayer, it works well. No warning, and the grad is fine.
It really bothers me. I wish someone could help me.
Thank you in advance.
Ok, I've solved the problem. It's really a stupid mistake.
The problem lies in the '@staticmethod' decorators!
The GradReverseLayer should be like:
class GradReverse(Function):
    """Gradient reversal layer for use via ``GradReverse.apply(x, alpha)``.

    Forward is the identity on ``x``. Backward multiplies the incoming
    gradient by ``-alpha`` and returns no gradient for ``alpha``.
    Both methods must be static methods, as autograd requires for custom
    Functions.
    """

    @staticmethod
    def forward(ctx, x, alpha, **kwargs):
        # Extra keyword arguments are accepted and ignored (kept from the
        # original signature).
        ctx.alpha = alpha
        # Return a view rather than x itself -- the usual idiom for identity
        # forwards in custom autograd Functions.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed and scaled for x, None for alpha.
        output = grad_output * -ctx.alpha
        return output, None
I always think it's just a note! And I delete it every time!
But I still don't know why
src_domain_loss=torch.tensor(src_domain_loss, requires_grad=True)
this would lead to zero grad in backward process. Does anybody have some ideas? It will be nice if we can discuss about it lol.
I am trying to have a simple subclass of OrderedDict that gets created by a Pool then returned.
It seems that the pickling process when returning the created object to the pool tries to re-instantiate the object and fails due to the required additional argument in the __init__ function.
This is a minimal (non) working example:
from collections import OrderedDict
from multiprocessing import Pool
class Obj1(OrderedDict):
    """OrderedDict subclass that carries an extra, required attribute ``x``.

    NOTE: instances of this class cannot be unpickled as written, which is the
    failure seen when results come back from ``pool.imap_unordered``.
    OrderedDict's ``__reduce__`` rebuilds the object by calling the class with
    no arguments, so unpickling runs ``Obj1()``, and ``__init__`` raises
    "missing 1 required positional argument: 'x'". A plain ``dict`` subclass
    is rebuilt without calling ``Obj1.__init__``, which is why that variant
    works.
    """
    def __init__(self, x, *args, **kwargs):
        # Remaining positional/keyword arguments initialise the mapping contents.
        super().__init__(*args, **kwargs)
        self.x = x

def task(x):
    """Build an ``Obj1`` in the worker; returning it sends it to the parent via pickle."""
    obj1 = Obj1(x)
    return obj1

if __name__ == '__main__':
    with Pool(1) as pool:
        # Each result is unpickled in the parent's result-handler thread.
        for x in pool.imap_unordered(task, (1,2,3)):
            print(x.x)
If I do this I get the following error.
Exception in thread Thread-3:
Traceback (most recent call last):
File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/usr/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.6/multiprocessing/pool.py", line 463, in _handle_results
task = get()
File "/usr/lib/python3.6/multiprocessing/connection.py", line 251, in recv
return _ForkingPickler.loads(buf.getbuffer())
TypeError: init() missing 1 required positional argument: 'x'
Again this fails when the task functions returns to the pool and I guess the object gets pickled?
If I replace OrderedDict with a plain dict, it works flawlessly.
I have a workaround to use kwargs and retrieve the attribute of interest but I am stumped about the error to start with. Any ideas?
You can define __getstate__() and __setstate__() methods for your class.
In those functions you can make sure that x is handled as well. For example:
def __reduce__(self):
    """Pickle support for an OrderedDict subclass whose ``__init__`` requires ``x``.

    OrderedDict's own ``__reduce__`` re-creates the object by calling the
    class with no arguments. That fails when ``__init__`` has a required
    parameter, so ``__getstate__``/``__setstate__`` alone do not help.
    Returning ``(cls, (x,), state)`` makes pickle/copy call ``cls(x)`` and
    then ``__setstate__(state)``.
    """
    return type(self), (self.x,), self.__getstate__()

def __getstate__(self):
    """Return ``(x, items)``.

    The items are copied into a list because dict views such as
    ``self.items()`` cannot be pickled.
    """
    return self.x, list(self.items())

def __setstate__(self, state):
    """Restore ``x`` and the mapping contents produced by ``__getstate__``."""
    self.x = state[0]
    self.update(state[1])
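BTW, since CPython 3.6 plain dicts preserve insertion order, so you rarely need OrderedDict just for ordering. This was originally an implementation detail of CPython; Python 3.7 made it part of the language. OrderedDict is still useful when you need its extra features, such as move_to_end() or order-sensitive equality comparison.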
I've created a class property to handle image data and then I'm trying to assign data to this property using a 3-element tuple. For some reason Python seems to think that my tuple contains just one element. Any ideas about what's going on here?
The property setter is defined as follows:
#data.setter
def data(self, *args):
    """Setter for the ``data`` property; intended input is ``(image_array, dtype, sizes)``.

    If ``image_array`` is None, an uninitialised array of shape ``sizes`` and
    type ``dtype`` is allocated. Otherwise ``image_array`` is converted with
    ``np.array(image_array, dtype)``. Either way ``self._set_color_data()`` is
    called afterwards.

    NOTE: a property setter receives exactly one value (the right-hand side
    of the assignment). With ``*args`` that value arrives wrapped as
    ``args == ((image_array, dtype, sizes),)``, so the unpacking below sees a
    single element and raises ``ValueError: not enough values to unpack
    (expected 3, got 1)``. Declaring the parameter as ``args`` (no star)
    avoids the wrapping.
    """
    image_array, dtype, sizes = args
    if image_array is None:
        self._data = np.empty(sizes, dtype)
    else:
        self._data = np.array(image_array, dtype)
    self._set_color_data()
And upon execution, I get the following output:
test = (image_temp, np.uint8, sizes)
print(len(test))
>>> 3
self.image5d.data = test
Traceback (most recent call last):
File "C:\***\Python36\lib\tkinter\__init__.py", line 1699, in __call__
return self.func(*args)
File "c:***\mmCIAD\mmciad.py", line 88, in open_file
self.image5d.data = test
File "c:\***\mmCIAD\mmciad.py", line 172, in data
image_array, dtype, sizes = args
ValueError: not enough values to unpack (expected 3, got 1)
Any help will be much appreciated!
You should use:
def data(self, args):
Currently, *args packs the assigned tuple into the first (and only) element of a new tuple, so you get:
args = ((v1, v2, v3),)
I need to write a function f. It takes a function g and a set of *args and **kwargs as input. It's supposed to call the input function g with the arguments and return its result. One requirement: if the given arguments are not accepted by the function, I should raise a custom exception instead of letting Python raise its own TypeError. How do I know whether the given arguments can be used to call the given function successfully?
Example:
def f(g, *args, **kwargs):
    """Return ``g(*args, **kwargs)``, or raise CustomTypeError for unacceptable arguments.

    ``bad_arguments`` and ``CustomTypeError`` are placeholders that this
    snippet does not define.
    """
    # How to implement bad_arguments()?
    if not bad_arguments(g, *args, **kwargs):
        return g(*args, **kwargs)
    raise CustomTypeError()
My first guess is to use the inspect module. I know I can look at all the expected arguments by inspecting the function. But then how do I determine whether the provided arguments fulfill the requirements, especially considering there might be variable arguments (i.e. *args and **kwargs) in g's signature, and the calling arguments might specify positional arguments by name? It seems complicated enough that manual logic would be unreliable here.
EDIT
Please see my comments to this question. I hope they clarify my question a bit more.
Also, to ask the question in a different way: you know how python checks the arguments before actually invoking your function body and raises TypeError if it finds something mismatching (e.g. when an argument is not provided, or when a named parameter which is not on the argument list is provided)? I basically want to do the same logic here then raise my CustomTypeError.
**EDIT in reply to @SigmaPiEpsilon:**
I seem to have found a bug in your example code (which is totally fine, considering it was just an example to illustrate your idea).
My point being, manual reproduction of this standard python logic might be erroneous. And that's why I prefer a more systematic way, if you will.
from inspect import signature
def check_args(f, *args, **kwargs):
    """Hand-rolled argument check based on each parameter's ``kind``.

    Prints the collected parameter names, then prints a diagnostic for the
    first mismatch it detects. If it detects none, it calls ``f`` (the result
    is unpacked into ``z, u`` but not returned).

    NOTE: this manual logic is not equivalent to Python's own call rules:
      * ``parameters`` only has POSITIONAL_OR_KEYWORD and KEYWORD_ONLY keys,
        so a function with any other parameter kind (positional-only,
        ``*args``, ``**kwargs``) raises KeyError in the loop below.
      * A keyword argument aimed at a POSITIONAL_OR_KEYWORD parameter (such as
        ``kw`` in ``f`` below) is compared against the KEYWORD_ONLY list. The
        valid call ``check_args(f, 3, 4, 5, kw='a')`` is therefore reported as
        "Insufficient keyword arguments".
      * Defaults are never consulted, and too few positional arguments are
        not reported (the call to ``f`` then raises TypeError itself).
    """
    sig = signature(f)
    parameters = {"POSITIONAL_OR_KEYWORD" : [], "KEYWORD_ONLY" : []}
    for elem in sig.parameters.values():
        parameters[str(elem.kind)].append(elem.name)
    print(parameters)
    if len(args) > len(parameters["POSITIONAL_OR_KEYWORD"]):
        print("More positional arguments")
    elif len(kwargs) != len(parameters["KEYWORD_ONLY"]):
        print("Insufficient keyword arguments")
    elif set(kwargs.keys()) != set(parameters["KEYWORD_ONLY"]):
        print("Provided keywords %s does not match function keywords %s" %(list(kwargs.keys()),parameters["KEYWORD_ONLY"]))
    else:
        z,u = f(*args,**kwargs)

def f(x, y, z, kw="Hello"):
    """Demo target: prints ``z`` and ``kw`` and returns ``(z, x + y)``; ``v`` is unused."""
    u = x + y
    v = x/y
    print(z)
    print(kw)
    return z, u

f(3, 4, 5, kw='a')
check_args(f, 3, 4, 5, kw='a')
The output is:
$ python test.py
5
a
{'POSITIONAL_OR_KEYWORD': ['x', 'y', 'z', 'kw'], 'KEYWORD_ONLY': []}
Insufficient keyword arguments
This is a difficult problem to solve generally due to the flexibility of python function arguments (positional, keywords, keyword only etc). As such it is probably better to solve specific cases that suits your applications. For details look into python Signature and Parameter objects of python inspect module. A crude example is provided below to illustrate one approach that uses the bind() method of the signature object. You can adapt this to fit your specific example.
Edit: Added a version of the check in line with OP's requirement. Check previous edits for more customized checking of arguments
from inspect import signature
def check_args(f, *args, **kwargs):
    """Return True if ``f(*args, **kwargs)`` would pass Python's argument binding.

    ``Signature.bind`` applies the interpreter's own matching rules
    (positional vs. keyword, keyword-only parameters, defaults, ``*args`` /
    ``**kwargs``) without calling ``f``, and raises TypeError on a mismatch.

    Raises:
        ValueError: from ``inspect.signature`` if no signature can be
            provided for ``f`` (e.g. some built-ins).
        TypeError: from ``inspect.signature`` if that kind of object is not
            supported. This is raised before binding, so it is not turned
            into False.
    """
    sig = signature(f)
    try:
        sig.bind(*args, **kwargs)
    except TypeError:
        return False
    return True


def f(g, *args, **kwargs):
    """Call ``g(*args, **kwargs)`` after checking the arguments against its signature.

    Raises:
        Exception: with message "Bad Arguments" if the arguments cannot be
            bound to ``g``'s parameters; ``g`` is not called in that case.
    """
    if not check_args(g, *args, **kwargs):
        raise Exception("Bad Arguments")
    return g(*args, **kwargs)


def g(x, y, *, z, kw="Hello"):
    """Demo target: two positional parameters, required keyword-only ``z``, optional ``kw``.

    Prints ``z`` and ``kw`` and returns ``x + y`` (``v`` is computed but unused).
    """
    u = x + y
    v = x*y
    print(z)
    print(kw)
    return u


f(g, 3, 4, z=5, kw="Hello")
Test in python 3.4
$ python3.4 -i function_check2.py
5
Hello
>>> f(g,3,4,5,z = 5)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "function_check2.py", line 16, in f
raise Exception("Bad Arguments")
Exception: Bad Arguments
>>> f(g,3,4,z = 5)
5
Hello
7
>>> f(g,3,4,kw = "Hello")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "function_check2.py", line 16, in f
raise Exception("Bad Arguments")
Exception: Bad Arguments