About function re-compilation each time we call jax.jit - jax

I am a newbie to jax. While reading the documentation, I got confused about the caching behavior of jit.
In the caching section, it says: "Avoid calling jax.jit inside loops. Doing that effectively creates a new f at each call, which will get compiled each time instead of reusing the same cached function". However, running the following code produces the print side effect only once:
import jax

def unjitted_loop_body(prev_i):
    print("tracing...")
    return prev_i + 1

def g_inner_jitted_poorly(x, n):
    i = 0
    while i < n:
        # Don't do this!
        i = jax.jit(unjitted_loop_body)(i)
    return x + i

g_inner_jitted_poorly(10, 20)
# output:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
tracing...
Out[1]: DeviceArray(30, dtype=int32)
The string "tracing..." is only printed once, and it seems that jit does not trace the function again.
Is this intended? Thanks for any help!

Your example is behaving as expected; I suspect you may have been reading an older version of the docs.
In the current version of the docs you link to, the example is slightly different: rather than jit-compiling the function directly, it jit-compiles a temporary function wrapper:
def g_inner_jitted_partial(x, n):
    i = 0
    while i < n:
        # Don't do this! each time the partial returns
        # a function with different hash
        i = jax.jit(partial(unjitted_loop_body))(i)
    return x + i
If you do this, it will cause re-compiles every time, because partial(unjitted_loop_body) is a new function object in every loop iteration, and JIT caching is based on the id of the function object.
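You can check this directly (a small standalone snippet, not from the docs): every call to partial() builds a new object, while a module-level function is always the same object:
from functools import partial

def unjitted_loop_body(prev_i):
    return prev_i + 1

print(partial(unjitted_loop_body) is partial(unjitted_loop_body))  # False: a new wrapper object each time
print(unjitted_loop_body is unjitted_loop_body)                    # True: one persistent function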
Using the approach in your question is fine, because the function passed to jit is a persistent global function, whose id is the same every iteration:
def g_inner_jitted_normal(x, n):
    i = 0
    while i < n:
        # this is OK, since JAX can find the
        # cached, compiled function
        i = jax.jit(unjitted_loop_body)(i)
    return x + i
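For what it's worth, if you want to avoid even the small overhead of re-wrapping on every iteration, you can jit once outside the loop and reuse the wrapped callable. A sketch (g_inner_jitted_hoisted is just an illustrative name, not from the docs):
import jax

def unjitted_loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted_hoisted(x, n):
    jitted_body = jax.jit(unjitted_loop_body)  # wrap once, reuse the same jitted callable every iteration
    i = 0
    while i < n:
        i = jitted_body(i)
    return x + i

print(g_inner_jitted_hoisted(10, 20))  # 30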

Related

Why is my merge sort algorithm not working?

I am implementing the merge sort algorithm in Python. I previously implemented the same algorithm in C and it works fine there, but when I implement it in Python, it outputs an unsorted array.
I've already rechecked the algorithm and the code, and to my knowledge the code seems to be correct.
I think the issue is related to the scope of variables in Python, but I don't have any clue how to solve it.
from random import shuffle

# Function to merge the arrays
def merge(a,beg,mid,end):
    i = beg
    j = mid+1
    temp = []
    while(i<=mid and j<=end):
        if(a[i]<a[j]):
            temp.append(a[i])
            i += 1
        else:
            temp.append(a[j])
            j += 1
    if(i>mid):
        while(j<=end):
            temp.append(a[j])
            j += 1
    elif(j>end):
        while(i<=mid):
            temp.append(a[i])
            i += 1
    return temp
# Function to divide the arrays recursively
def merge_sort(a,beg,end):
    if(beg<end):
        mid = int((beg+end)/2)
        merge_sort(a,beg,mid)
        merge_sort(a,mid+1,end)
        a = merge(a,beg,mid,end)
    return a

a = [i for i in range(10)]
shuffle(a)
n = len(a)
a = merge_sort(a, 0, n-1)
print(a)
To make it work you need to change merge_sort declaration slightly:
def merge_sort(a,beg,end):
    if(beg<end):
        mid = int((beg+end)/2)
        merge_sort(a,beg,mid)
        merge_sort(a,mid+1,end)
        a[beg:end+1] = merge(a,beg,mid,end)  # <-- this line changed
    return a
Why:
temp is constructed to hold no more than end-beg+1 elements, but a is the full initial array; if you replaced all of a with it, the array would be ruined quickly. Therefore we take a "slice" of a and replace only the values in that slice.
Why your version did not work:
Your a was luckily not getting replaced at all, because of Python's inner workings. That is a bit tricky to explain, but I'll try.
Every variable in Python is a reference. a is a reference to a list of variables a[i], which are in turn references to the actual values in memory.
When you pass a to a function, Python creates a new local variable a that points to the same list. That means reassigning it with a = ... only changes where the local a points. You can only pass changes outside either via "slices" or via a return statement.
Why "slices" work:
Slices are tricky. As I said, a points to an array of other variables (basically a[i]), which in turn are references to the actual data in memory. When you reassign a slice, Python goes through the slice element by element and changes where those individual variables point; since a inside and outside the function still refer to the same elements, the changes go through.
Hope it makes sense.
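A tiny self-contained demonstration of the difference (hypothetical helper names, not from the question):
def rebind(lst):
    lst = [0, 0, 0]      # rebinds the local name only; the caller's list is untouched

def assign_slice(lst):
    lst[:] = [0, 0, 0]   # mutates the same list object the caller sees

a = [1, 2, 3]
rebind(a)
print(a)        # [1, 2, 3]
assign_slice(a)
print(a)        # [0, 0, 0]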
You don't use the results of the recursive merges, so you essentially report the result of the merge of the two unsorted halves.

How to call my function in another function

I am getting a value from the user in getInteger.
I need to get the output from sqInteger in getInteger.
No matter how I set up the parameters or indent the sqInteger function, variable x is undefined.
I added a return line to try and pass the x variable, but that's definitely not helping.
Please help me understand what I'm missing!
def getInteger():
    while True:
        try:
            x = int(input('Enter an integer: '))
        except ValueError:
            print()
            print('That\'s not an integer. Try again.')
            continue
        else:
            return x
            print(x)
            break

def sqInteger(getInteger, x):
    y = x**2
    print(y)
Is this the entire code? You need to call the getInteger() function at some point in the code before that loop will begin. You're also not calling function sqInteger() at any point.
Your exception handler will immediately stop evaluating the try block and move down to the except block upon a non-integer being typed into the input. Therefore, you can place a call to the sqInteger() function after the input() function. If the user types a non-integer into the terminal, it will move down to your Exception handler and prompt the user to retry. If they enter an integer, the code will continue to evaluate and run the function sqInteger.
For this, you also do not need to pass getInteger into the sqInteger() function. You are technically allowed to pass functions as parameters in Python but it's not necessary for this and probably out of the scope of this program.
So the following code would be suitable:
def getInteger():
    while True:
        try:
            x = int(input('Enter an integer: '))
            # variable 'squared' now receives the return value from the function
            squared = sqInteger(x)  # call to function sqInteger necessary for this function to be executed
        except ValueError:
            print('That\'s not an integer. Try again.')
            continue
        else:
            print(x)  # if user entered 2, prints 2, not 4
            return x  # this value is still only what the user input, not the result of sqInteger()
            break

def sqInteger(x):
    y = x**2
    print(y)
    return y  # you need to return values from functions in order to access them from outside the function
The reason you pass a variable into a function (as a parameter) is to give that function access to that variable. Creating a function creates a local scope for that function so that variables named within that function are in a separate namespace from variables outside that function. This is useful in large programs where many variables might exist and you need to keep them separate.
Because you've separately defined the sqInteger function, it does not have access to variables outside of its scope. You need to pass in the variables that you'd like it to have access to.
You also need to call functions before they will run. Defining a function only serves to set up the function so that it can be called as one functional unit. It's useful for separating concerns within a program. The ability to call a function is useful because it allows you to separate your code out and only mention a single call to a function rather than having the entire functionality jumbled in with the rest of the code. It also allows for reusability of code.
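To make that concrete, here is a minimal sketch (simplified, hypothetical definitions, not your exact code): nothing runs until a top-level call starts the chain.
def sqInteger(x):
    return x ** 2

def getInteger():
    while True:
        try:
            return int(input('Enter an integer: '))
        except ValueError:
            print("That's not an integer. Try again.")

# Defining the functions above executes nothing by itself; this call starts the program:
value = getInteger()
print(sqInteger(value))  # squares whatever the user typed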
You can also have access to the result of the squared integer by returning a value and assigning this value to a function call, like such:
def sqInteger(x):
    y = x**2
    return y

# let's say x = 4
x = 4
squared = sqInteger(x)  # squared is now 16
This would NOT work:
x = input("Enter integer") #lets say you enter 3
squared = sqInteger()
print(squared)
def sqInteger():
print(x) # error: x is not defined
return x**2 # error: x is not defined
The function does not have access to outside variables like x. It must be passed these variables as parameters so that you can call the function and set the parameters at will. This is for the sake of modularity in a program. You can pass it all sorts of different integers as parameters, which gives you a reusable function for any time you need to square an integer.
Edit: Sorry this was a mess, I finally fixed all the errors in my explanation though...

asyncio with map&reduce flavor and without flooding the event loop

I am trying to use asyncio in real applications and it doesn't go that easily; the help of asyncio gurus is badly needed.
Tasks that spawn other tasks without flooding event loop (Success!)
Consider a task like crawling the web starting from some "seeding" web-pages. Each web-page leads to the generation of new downloading tasks in exponential(!) progression. However, we want neither to flood the event loop nor to overload our network; we'd like to control the task flow. This is what I achieve well with a modification of Maxime's nice solution proposed here:
https://mail.python.org/pipermail/python-list/2014-July/687823.html
map & reduce (Fail)
Well, but I'd also need a very natural thing, a kind of map() & reduce() (or functools.reduce(), since we are on Python 3 already). That is, I'd need to call a "summarizing" function for all the downloading tasks completed on links from a page. This is where I fail :(
I'd propose an oversimplified but still nice test to model the use case: let's use a Fibonacci function implementation in its inefficient form.
That is, let coro_sum() be applied in reduce() and coro_fib() be what we apply with map(). Something like this:
@asyncio.coroutine
def coro_sum(x):
    return sum(x)

@asyncio.coroutine
def coro_fib(x):
    if x < 2:
        return 1
    res_coro = executor_pool.spawn_task_when_arg_list_of_coros_ready(
        coro=coro_sum,
        arg_coro_list=[coro_fib(x - 1), coro_fib(x - 2)])
    return res_coro
So that we could run the following tests.
Test #1 on one worker:
executor_pool = ExecutorPool(workers=1)
executor_pool.as_completed( coro_fib(x) for x in range(20) )
Test #2 on two workers:
executor_pool = ExecutorPool(workers=2)
executor_pool.as_completed( coro_fib(x) for x in range(20) )
It is very important that each coro_fib() and coro_sum() invocation is done via a Task on some worker, not just spawned implicitly and unmanaged!
It would be cool to find asyncio gurus interested in this very natural goal.
Your help and ideas would be very much appreciated.
best regards
Valery
There are multiple ways to compute the Fibonacci series asynchronously. First, check that the explosive variant fails in your case:
@asyncio.coroutine
def coro_sum(summands):
    return sum(summands)

@asyncio.coroutine
def coro_fib(n):
    if n == 0: s = 0
    elif n == 1: s = 1
    else:
        summands, _ = yield from asyncio.wait([coro_fib(n-2), coro_fib(n-1)])
        s = yield from coro_sum(f.result() for f in summands)
    return s
You could replace summands with:
a = yield from coro_fib(n-2) # don't return until it's ready
b = yield from coro_fib(n-1)
s = yield from coro_sum([a, b])
In general, to prevent the exponential growth, you could use the asyncio.Queue (synchronization via communication) and asyncio.Semaphore (synchronization via a mutex) primitives.
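As a rough illustration of the Semaphore approach (a sketch in modern async/await syntax; fetch, main and the fake URL list are made-up stand-ins, not part of your code): at most `limit` downloads run at once, so spawning many tasks does not flood the event loop or the network.
import asyncio

async def fetch(url, sem):
    async with sem:                  # wait here if `limit` fetches are already running
        await asyncio.sleep(0.1)     # stand-in for the real download
        return url

async def main(urls, limit=10):
    sem = asyncio.Semaphore(limit)
    tasks = [asyncio.create_task(fetch(u, sem)) for u in urls]
    return await asyncio.gather(*tasks)

print(asyncio.run(main([f"page-{i}" for i in range(100)])))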

IronPython Multiprocessing Module

I have some DLLs written in .NET which I am using to access thousands of binary files. The data is compartmentalized by directories, so as a performance enhancement I thought of using multiple processes or threads to churn through the files.
I have a function, currently part of the main class (requiring self as an argument); this could easily be refactored into a private method.
My first inclination was to use the multiprocessing module, but that doesn't seem to be available for IronPython.
My next thought was to use Task:
def __createThreads(self):
    tasks = Array.CreateInstance(Task, 5)
    for idx in range(0, 5):
        tasks.append(Task.Factory.StartNew(self.__doWork, args=(idx,)))
    Task.WaitAll(tasks)

def __doWork(self, idx):
    for index in range(0, idx):
        print "Thread: %d | Index: %d" % (idx, index)
Or to use Thread
def __createThreads(self):
    threads = list()
    for idx in range(0, 5):
        t = Thread(ThreadStart(self.__doWork))
        t.Start()
        threads.append(t)
    while len(threads) > 0:
        time.sleep(.05)
        for t in threads:
            if(not t.IsAlive):
                threads.remove(t)
What I cannot find is an IronPython example of how to pass arguments.
Please be aware that your two examples are not exactly equivalent. The task version will only create/use actual concurrent threads when the run-time thinks it is a good idea (unless you specify TaskCreationOptions.LongRunning). You have to decide what works for your use-case.
The easiest way to pass the idx argument to the __doWork function would be to use a lambda to capture the value and the invocation (please be aware of scoping issues, as discussed in this question, which also hints at alternative solutions for introducing an intermediate scope):
tasks.append(Task.Factory.StartNew(lambda idx = idx: self.__doWork(idx)))
As a side-note: You will have to convert your task list to an array in order for Task.WaitAll to be happy.
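Putting both remarks together, a rough sketch of how the method could look in IronPython (assuming the same class and __doWork as in your question, and that the .NET 4 Task API is available and referenced):
from System import Array
from System.Threading.Tasks import Task

def __createThreads(self):
    tasks = []
    for idx in range(0, 5):
        # idx=idx captures the current value; a bare lambda would see the
        # final loop value of idx in every task
        tasks.append(Task.Factory.StartNew(lambda idx=idx: self.__doWork(idx)))
    # Task.WaitAll expects a typed .NET array, not a Python list
    Task.WaitAll(Array[Task](tasks))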

Python 3.x: Test if generator has elements remaining

When I use a generator in a for loop, it seems to "know" when there are no more elements to yield. Now I have to use a generator WITHOUT a for loop and call next() by hand to get the next element. My problem is: how do I know if there are no more elements?
I only know that next() raises an exception (StopIteration) if there is nothing left, BUT isn't an exception a little bit too "heavy" for such a simple problem? Isn't there a method like has_next() or so?
The following lines should make clear, what I mean:
#!/usr/bin/python3

# define a list of some objects
bar = ['abc', 123, None, True, 456.789]

# our primitive generator
def foo(bar):
    for b in bar:
        yield b

# iterate, using the generator above
print('--- TEST A (for loop) ---')
for baz in foo(bar):
    print(baz)
print()

# assign a new iterator to a variable
foobar = foo(bar)

print('--- TEST B (try-except) ---')
while True:
    try:
        print(foobar.__next__())
    except StopIteration:
        break
print()

# assign a new iterator to a variable
foobar = foo(bar)

# display generator members
print('--- GENERATOR MEMBERS ---')
print(', '.join(dir(foobar)))
The output is as follows:
--- TEST A (for loop) ---
abc
123
None
True
456.789
--- TEST B (try-except) ---
abc
123
None
True
456.789
--- GENERATOR MEMBERS ---
__class__, __delattr__, __doc__, __eq__, __format__, __ge__, __getattribute__, __gt__, __hash__, __init__, __iter__, __le__, __lt__, __name__, __ne__, __new__, __next__, __reduce__, __reduce_ex__, __repr__, __setattr__, __sizeof__, __str__, __subclasshook__, close, gi_code, gi_frame, gi_running, send, throw
Thanks to everybody, and have a nice day! :)
This is a great question. I'll try to show you how we can use Python's introspective abilities and open source to get an answer. We can use the dis module to peek behind the curtain and see how the CPython interpreter implements a for loop over an iterator.
>>> def for_loop(iterable):
... for item in iterable:
... pass # do nothing
...
>>> import dis
>>> dis.dis(for_loop)
2 0 SETUP_LOOP 14 (to 17)
3 LOAD_FAST 0 (iterable)
6 GET_ITER
>> 7 FOR_ITER 6 (to 16)
10 STORE_FAST 1 (item)
3 13 JUMP_ABSOLUTE 7
>> 16 POP_BLOCK
>> 17 LOAD_CONST 0 (None)
20 RETURN_VALUE
The juicy bit appears to be the FOR_ITER opcode. We can't dive any deeper using dis, so let's look up FOR_ITER in the CPython interpreter's source code. If you poke around, you'll find it in Python/ceval.c; you can view it here. Here's the whole thing:
TARGET(FOR_ITER)
    /* before: [iter]; after: [iter, iter()] *or* [] */
    v = TOP();
    x = (*v->ob_type->tp_iternext)(v);
    if (x != NULL) {
        PUSH(x);
        PREDICT(STORE_FAST);
        PREDICT(UNPACK_SEQUENCE);
        DISPATCH();
    }
    if (PyErr_Occurred()) {
        if (!PyErr_ExceptionMatches(
                        PyExc_StopIteration))
            break;
        PyErr_Clear();
    }
    /* iterator ended normally */
    x = v = POP();
    Py_DECREF(v);
    JUMPBY(oparg);
    DISPATCH();
Do you see how this works? We try to grab an item from the iterator; if we fail, we check what exception was raised. If it's StopIteration, we clear it and consider the iterator exhausted.
So how does a for loop "just know" when an iterator has been exhausted? Answer: it doesn't -- it has to try and grab an element. But why?
Part of the answer is simplicity. Part of the beauty of implementing iterators is that you only have to define one operation: grab the next element. But more importantly, it makes iterators lazy: they'll only produce the values that they absolutely have to.
Finally, if you are really missing this feature, it's trivial to implement it yourself. Here's an example:
class LookaheadIterator:
    def __init__(self, iterable):
        self.iterator = iter(iterable)
        self.buffer = []

    def __iter__(self):
        return self

    def __next__(self):
        if self.buffer:
            return self.buffer.pop()
        else:
            return next(self.iterator)

    def has_next(self):
        if self.buffer:
            return True
        try:
            self.buffer = [next(self.iterator)]
        except StopIteration:
            return False
        else:
            return True

x = LookaheadIterator(range(2))
print(x.has_next())
print(next(x))
print(x.has_next())
print(next(x))
print(x.has_next())
print(next(x))
The two statements you wrote deal with finding the end of the generator in exactly the same way. The for loop simply calls next() on the iterator until the StopIteration exception is raised, and then it terminates.
http://docs.python.org/tutorial/classes.html#iterators
As such I don't think waiting for the StopIteration exception is a 'heavy' way to deal with the problem, it's the way that generators are designed to be used.
It is not possible to know about the end of an iterator beforehand in the general case, because arbitrary code may have to run to decide about the end. Buffering elements could help reveal things, at a cost, but this is rarely useful.
In practice the question arises when one wants to take only one or a few elements from an iterator for now, but does not want to write that ugly exception-handling code (as indicated in the question). Indeed it is non-Pythonic to put the concept "StopIteration" into normal application code. And exception handling at the Python level is rather time-consuming, particularly when it's just about taking one element.
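For example, with a generator like the following (a made-up illustration), there is no way to answer a hypothetical has_next() without actually running the loop body one more time:
import random

def unpredictable():
    # Whether another element exists depends on code that has not run yet.
    while random.random() < 0.9:
        yield "item"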
The most Pythonic way to handle those situations is either using for .. break [.. else], like:
for x in iterator:
    do_something(x)
    break
else:
    it_was_exhausted()
or using the builtin next() function with default like
x = next(iterator, default_value)
or using iterator helpers e.g. from itertools module for rewiring things like:
max_3_elements = list(itertools.islice(iterator, 3))
Some iterators however expose a "length hint" (PEP424) :
>>> gen = iter(range(3))
>>> gen.__length_hint__()
3
>>> next(gen)
0
>>> gen.__length_hint__()
2
Note: iterator.__next__() should not be used by normal application code; that's why it was renamed from Python 2's iterator.next(). And using next() without a default is not much better ...
This may not precisely answer your question, but I found my way here looking to elegantly grab a result from a generator without having to write a try: block. A little googling later I figured this out:
def g():
    yield 5

result = next(g(), None)
Now result is either 5 or None, depending on how many times you've called next on the iterator, or depending on whether the generator function returned early instead of yielding.
I strongly prefer handling None as an output over raising for "normal" conditions, so dodging the try/catch here is a big win. If the situation calls for it, there's also an easy place to add a default other than None.
