What is the complexity of the following Python code? (python-3.x)

import math

# The loop runs up to trunc(k/2): if k/2 = 3.5, then i goes over 1, 2, 3.
def pow(x, k):
    y = 1
    m = math.trunc(k / 2)
    if k == 0:
        return 1
    for i in range(1, m + 1):
        y = y * x
    if k % 2 == 0:
        return y * y
    else:
        return y * y * x
I think it uses O(1) memory and O(log k) steps. Is it correct?

The memory complexity is correct: it is O(1).
The runtime complexity, however, is O(k). Every line runs in constant time except the for loop, which iterates at most k/2 times, so the number of iterations is linear in k. Note that O(k/2) = O(k).
You can also see directly that O(log k) cannot be right: if k = 16, the loop iterates 8 times, whereas log base 2 of 16 is only 4.
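For comparison, an implementation that actually runs in O(log k) steps would use exponentiation by squaring, halving the exponent on every iteration rather than only once at the start. A minimal sketch of that approach (my own illustration, not the poster's code):

import math

def fast_pow(x, k):
    # Exponentiation by squaring: k is halved on every pass,
    # so the loop body runs O(log k) times.
    result = 1
    while k > 0:
        if k % 2 == 1:   # odd exponent: fold one factor of x into the result
            result *= x
        x *= x           # square the base
        k //= 2          # halve the exponent
    return result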

Related

find sequence of elements in numpy array [duplicate]

In Python or NumPy, what is the best way to find out the first occurrence of a subarray?
For example, I have
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
What is the fastest way (run-time-wise) to find out where b occurs in a? I understand for strings this is extremely easy, but what about for a list or numpy ndarray?
Thanks a lot!
[EDITED] I prefer the numpy solution, since from my experience numpy vectorization is much faster than Python list comprehension. Meanwhile, the big array is huge, so I don't want to convert it into a string; that will be (too) long.
I'm assuming you're looking for a numpy-specific solution, rather than a simple list comprehension or for loop. One straightforward approach is to use the rolling window technique to search for windows of the appropriate size.
This approach is simple, works correctly, and is much faster than any pure Python solution. It should be sufficient for many use cases. However, it is not the most efficient approach possible, for a number of reasons. For an approach that is more complicated, but asymptotically optimal in the expected case, see the numba-based rolling hash implementation in norok2's answer.
Here's the rolling_window function:
>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a.strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
...
Then you could do something like
>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)
To make this really useful, you'd have to reduce it along axis 1 using all:
>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False, True, False, False, False, False], dtype=bool)
Then you could use that however you'd use a boolean array. A simple way to get the index out:
>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])
For lists you could adapt one of these rolling window iterators to use a similar approach.
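For example, a minimal deque-based sketch for plain Python sequences (find_in_list is a hypothetical name of my own; it works on any iterable):

from collections import deque

def find_in_list(seq, sub):
    # Slide a fixed-size window over seq and compare it with sub.
    sub = list(sub)
    window = deque(maxlen=len(sub))
    for i, item in enumerate(seq):
        window.append(item)
        if len(window) == len(sub) and list(window) == sub:
            yield i - len(sub) + 1   # index where the match starts

list(find_in_list([1, 2, 3, 4, 5, 6], [2, 3, 4]))  # -> [1]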
For very large arrays and subarrays, you could save memory like this:
>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
...
>>> hits
array([False, False, False, True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)
On the other hand, this will probably be somewhat slower.
The following code should work:
[x for x in xrange(len(a)) if a[x:x+len(b)] == b]
This returns the list of indices at which the pattern starts.
(EDITED to include a deeper discussion, better code and more benchmarks)
Summary
For raw speed and efficiency, one can use a Cython or Numba accelerated version (when the input is a Python sequence or a NumPy array, respectively) of one of the classical algorithms.
The recommended approaches are:
find_kmp_cy() for Python sequences (list, tuple, etc.)
find_kmp_nb() for NumPy arrays
Other efficient approaches are find_rk_cy() and find_rk_nb(), which are more memory-efficient but are not guaranteed to run in linear time.
If Cython / Numba are not available, both find_kmp() and find_rk() are again a good all-around solution for most use cases, although in the average case and for Python sequences the naïve approach, in some form, notably find_pivot(), may be faster. For NumPy arrays, find_conv() (from @Jaime's answer) outperforms any non-accelerated naïve approach.
(Full code is below, and here and there.)
Theory
This is a classical problem in computer science that goes by the name of string-searching or string matching problem.
The naïve approach, based on two nested loops, has a computational complexity of O(n + m) on average, but its worst case is O(nm).
Over the years, a number of alternative approaches have been developed which guarantee better worst-case performance.
Of the classical algorithms, the ones that can be best suited to generic sequences (since they do not rely on an alphabet) are:
the naïve algorithm (basically consisting of two nested loops)
the Knuth–Morris–Pratt (KMP) algorithm
the Rabin-Karp (RK) algorithm
This last algorithm relies on the computation of a rolling hash for its efficiency and therefore may require some additional knowledge of the input for optimal performance.
Ultimately, it is best suited to homogeneous data, for example numeric arrays.
A notable example of numeric arrays in Python is, of course, NumPy arrays.
Remarks
The naïve algorithm, by being so simple, lends itself to different implementations with various degrees of run-time speed in Python.
The other algorithms are less flexible in what can be optimized via language tricks.
Explicit looping in Python may be a speed bottleneck and several tricks can be used to perform the looping outside of the interpreter.
Cython is especially good at speeding up explicit loops for generic Python code.
Numba is especially good at speeding up explicit loops on NumPy arrays.
This is an excellent use-case for generators, so all the code will be using those instead of regular functions.
Python Sequences (list, tuple, etc.)
Based on the Naïve Algorithm
find_loop(), find_loop_cy() and find_loop_nb(), which are the explicit-loop-only implementations in pure Python, in Cython, and with Numba JITing, respectively. Note the forceobj=True in the Numba version, which is required because we are using Python object inputs.
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True

def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i

find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
find_all() replaces the inner loop with all() on a generator expression
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
find_slice() replaces the inner loop with direct comparison == after slicing []
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
find_mix() and find_mix2() replace the inner loop with direct comparison == after slicing [], but include one or two additional short-circuits on the first (and last) item, which may be faster because slicing with an int is much faster than slicing with a slice().
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i

def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
find_pivot() and find_pivot2() replace the outer loop with multiple .index() calls using the first item of the sub-sequence, while using slicing for the inner loop, optionally with additional short-circuiting on the last item (the first matches by construction). The multiple .index() calls are wrapped in an index_all() generator (which may be useful on its own).
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass

def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i

def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i
Based on Knuth–Morris–Pratt (KMP) Algorithm
find_kmp() is a plain Python implementation of the algorithm. Since there is no simple looping or places where one could use slicing with a slice(), there is not much to be done for optimization, except using Cython (Numba would require again forceobj=True which would lead to slow code).
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m:
        if subseq[j] == subseq[k]:
            k += 1
            offsets[j] = k
            j += 1
        else:
            if k != 0:
                k = offsets[k - 1]
            else:
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n:
        if seq[i] == subseq[j]:
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1]
        elif i < n and seq[i] != subseq[j]:
            if j != 0:
                j = offsets[j - 1]
            else:
                i += 1
find_kmp_cy() is a Cython implementation of the algorithm where the indices use a C data type, which results in much faster code.
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True

def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m:
        if subseq[j] == subseq[k]:
            k += 1
            offsets[j] = k
            j += 1
        else:
            if k != 0:
                k = offsets[k - 1]
            else:
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n:
        if seq[i] == subseq[j]:
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1]
        elif i < n and seq[i] != subseq[j]:
            if j != 0:
                j = offsets[j - 1]
            else:
                i += 1
Based on Rabin-Karp (RK) Algorithm
find_rk() is a pure Python implementation that relies on Python's hash() for the computation (and comparison) of the hash. The hash is made rolling by means of a simple sum(). The roll-over is then computed from the previous hash by subtracting the result of hash() on the item just left behind, seq[i - 1], and adding the result of hash() on the newly considered item, seq[i + m - 1].
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
find_rk_cy() is a Cython implementation of the algorithm where the indices use the appropriate C data type, which results in much faster code. Note that hash() truncates "the return value based on the bit width of the host machine."
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True

def find_rk_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    cdef Py_ssize_t old_item, new_item
    for i in range(1, n - m + 1):
        old_item = hash(seq[i - 1])
        new_item = hash(seq[i + m - 1])
        curr_hash += new_item - old_item  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
Benchmarks
The above functions are evaluated on two inputs:
random inputs
def gen_input(n, k=2):
    return tuple(random.randint(0, k - 1) for _ in range(n))
(almost) worst inputs for the naïve algorithm
def gen_input_worst(n, k=-2):
    result = [0] * n
    result[k] = 1
    return tuple(result)
The subseq has fixed size (32).
Since there are so many alternatives, two separate groupings have been made, and some solutions with very small variations and almost identical timings have been omitted (namely find_mix2() and find_pivot2()).
For each group both inputs are tested.
For each benchmark the full plot and a zoom on the fastest approach is provided.
Naïve on Random
Naïve on Worst
Other on Random
Other on Worst
(Full code is available here.)
NumPy Arrays
Based on the Naïve Algorithm
find_loop(), find_loop_cy() and find_loop_nb(), the explicit-loop-only implementations in pure Python, in Cython, and with Numba JITing, respectively. The code for the first two is the same as above and hence omitted. find_loop_nb() now enjoys fast JIT compilation. The inner loop has been written in a separate function because it can then be reused for find_rk_nb() (calling Numba functions inside Numba functions does not incur the function-call penalty typical of Python).
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
    for j in range(m):
        if seq[i + j] != subseq[j]:
            return False
    return True

@nb.jit
def find_loop_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if _is_equal_nb(seq, subseq, m, i):
            yield i
find_all() is the same as above, while find_slice(), find_mix() and find_mix2() are almost identical to the above, the only difference being that seq[i:i + m] == subseq is now the argument of np.all(): np.all(seq[i:i + m] == subseq).
find_pivot() and find_pivot2() share the same ideas as above, except that they now use np.where() instead of index_all(), and the array equality needs to be enclosed in an np.all() call.
def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif np.all(seq[i:i + m] == subseq):
            yield i

def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif seq[i + m - 1] == subseq[m - 1] \
                and np.all(seq[i:i + m] == subseq):
            yield i
find_rolling() expresses the looping via a rolling window, and the matching is checked with np.all(). This vectorizes all the looping at the expense of creating large temporary objects, while still substantially applying the naïve algorithm. (The approach is from @senderle's answer.)
def rolling_window(arr, size):
    shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)

def find_rolling(seq, subseq):
    bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
    yield from np.mgrid[0:len(bool_indices)][bool_indices]
find_rolling2() is a slightly more memory-efficient variation of the above, where the vectorization is only partial and one explicit loop (along the expected shortest dimension, the length of subseq) is kept. (The approach is also from @senderle's answer.)
def find_rolling2(seq, subseq):
    windows = rolling_window(seq, len(subseq))
    hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
    for i, x in enumerate(subseq):
        hits &= np.in1d(windows[:, i], [x])
    yield from hits.nonzero()[0]
Based on Knuth–Morris–Pratt (KMP) Algorithm
find_kmp() is the same as above, while find_kmp_nb() is a straightforward JIT-compilation of that.
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'
Based on Rabin-Karp (RK) Algorithm
find_rk() is the same as the above, except that again seq[i:i + m] == subseq is enclosed in an np.all() call.
find_rk_nb() is the Numba-accelerated version of the above. It uses _is_equal_nb(), defined earlier, to definitively determine a match, while for the hashing it uses a Numba-accelerated sum_hash_nb() function whose definition is pretty straightforward.
@nb.jit
def sum_hash_nb(arr):
    result = 0
    for x in arr:
        result += hash(x)
    return result

@nb.jit
def find_rk_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if _is_equal_nb(seq, subseq, m, 0):
        yield 0
    hash_subseq = sum_hash_nb(subseq)  # compute hash
    curr_hash = sum_hash_nb(seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
            yield i
find_conv() uses a pseudo Rabin-Karp method, where initial candidates are hashed using the np.dot() product and located on the convolution between seq and subseq with np.where(). The approach is pseudo because, while it still uses hashing to identify probable candidates, it may not be regarded as a rolling hash (that depends on the actual implementation of np.correlate()). Also, it needs to create a temporary array the size of the input. (The approach is from @Jaime's answer.)
def find_conv(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    yield from candidates[mask]
Benchmarks
Like before, the above functions are evaluated on two inputs:
random inputs
def gen_input(n, k=2):
    return np.random.randint(0, k, n)
(almost) worst inputs for the naïve algorithm
def gen_input_worst(n, k=-2):
    result = np.zeros(n, dtype=int)
    result[k] = 1
    return result
The subseq has fixed size (32).
These plots follow the same scheme as before, summarized below for convenience.
Since there are so many alternatives, two separate groupings have been made and some solutions with very small variations and almost identical timings have been omitted (namely find_mix2() and find_pivot2()).
For each group both inputs are tested.
For each benchmark the full plot and a zoom on the fastest approach is provided.
Naïve on Random
Naïve on Worst
Other on Random
Other on Worst
(Full code is available here.)
A convolution-based approach, which should be more memory-efficient than the stride_tricks-based approach:
def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]
With really big arrays it may not be possible to use a stride_tricks approach, but this one still works:
haystack = np.random.randint(1000, size=int(1e6))
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(int(1e6) - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle
In [3]: find_subsequence(haystack, needle)
Out[3]:
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
961100, 973481], dtype=int64)
In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True
In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop
You can call the tostring() method to convert an array to a string, and then use fast string search. This method may be faster when you have many subarrays to check.
import numpy as np
a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
print a.tostring().index(b.tostring())//a.itemsize
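Two caveats, sketched below under my own assumptions: tostring() is a deprecated alias of tobytes() in newer NumPy, and a raw byte match can in principle start in the middle of an element, so the byte offset should be checked for alignment before trusting it:

import numpy as np

a = np.array([1, 2, 3, 4, 5, 6])
b = np.array([2, 3, 4])

pos = a.tobytes().index(b.tobytes())  # tobytes() replaces tostring()
# A hit that is not aligned on an element boundary would be a false
# positive straddling two elements, so verify the alignment.
if pos % a.itemsize == 0:
    print(pos // a.itemsize)  # -> 1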
Another try, but I'm sure there is a more Pythonic & efficient way to do that ...
def array_match(a, b):
    for i in xrange(0, len(a)-len(b)+1):
        if a[i:i+len(b)] == b:
            return i
    return None
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
print array_match(a,b)
1
(This first answer was not in the scope of the question, as cdhowie mentioned)
set(a) & set(b) == set(b)
Here is a rather straight-forward option:
def first_subarray(full_array, sub_array):
    n = len(full_array)
    k = len(sub_array)
    matches = np.argwhere([np.all(full_array[start_ix:start_ix+k] == sub_array)
                           for start_ix in range(0, n-k+1)])
    return matches[0]
Then using the original a, b vectors we get:
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
first_subarray(a, b)
Out[44]:
array([1], dtype=int64)
Quick comparison of three of the proposed solutions (average time of 100 iterations for randomly created vectors):
import time
import collections
import numpy as np

def function_1(seq, sub):
    # direct comparison
    seq = list(seq)
    sub = list(sub)
    return [i for i in range(len(seq) - len(sub)) if seq[i:i+len(sub)] == sub]

def function_2(seq, sub):
    # Jaime's solution
    target = np.dot(sub, sub)
    candidates = np.where(np.correlate(seq, sub, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(sub))
    mask = np.all((np.take(seq, check) == sub), axis=-1)
    return candidates[mask]

def function_3(seq, sub):
    # HYRY solution
    return seq.tostring().index(sub.tostring())//seq.itemsize

# --- assessment of time performance
N = 100
seq = np.random.choice([0, 1, 2, 3, 4, 5, 6], 3000)
sub = np.array([1, 2, 3])

tim = collections.OrderedDict()
tim.update({function_1: 0.})
tim.update({function_2: 0.})
tim.update({function_3: 0.})

for function in tim.keys():
    for _ in range(N):
        seq = np.random.choice([0, 1, 2, 3, 4], 3000)
        sub = np.array([1, 2, 3])
        start = time.time()
        function(seq, sub)
        end = time.time()
        tim[function] += end - start

timer_dict = collections.OrderedDict()
for key, val in tim.items():
    timer_dict.update({key.__name__: val / N})
print(timer_dict)
Which would result (on my old machine) in:
OrderedDict([
('function_1', 0.0008518099784851074),
('function_2', 8.157730102539063e-05),
('function_3', 6.124973297119141e-06)
])
First, convert the lists to strings.
a = ''.join(str(i) for i in a)
b = ''.join(str(i) for i in b)
After converting to strings, you can easily find the index of the substring with the following string method.
a.index(b)
Cheers!!
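Note that without a separator this approach is ambiguous for multi-digit elements: [1, 23] and [12, 3] both become '123'. A safer sketch with wrapping commas (the commas are my own tweak so a partial number cannot match):

a_s = ',' + ','.join(str(i) for i in a) + ','
b_s = ',' + ','.join(str(i) for i in b) + ','
char_pos = a_s.index(b_s)
# Count the separators before the match to recover the list index.
list_index = a_s[:char_pos].count(',')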

Writing user defined function to evaluate the Saha equation with for-loops for an expected output

I am trying to create a function that evaluates the Saha function for certain values of temperature and electron pressure. The question is a little in depth so I will provide as much detail as possible about past code used before this section.
Previous sections code
Evaluating the partition function (part 1):
k = 8.617333262145179e-05
T = 10000.
g = 1.0
Ca_ion_energies = np.array([6.1131554, 11.871719, 50.91316, 67.2732, 84.34])  # in eV
Ca_partition_values = []

def partfunc_E(chiI, T):
    for chiI in Ca_ion_energies:
        elem = 0
        for i in np.arange(chiI):
            elem = elem + (g*np.exp(-(i/(k*T))))
        Ca_partition_values.append(elem)
    return Ca_partition_values

print(partfunc_E(Ca_ion_energies, T))
Output:
[1.455902590894594, 1.45633321917395, 1.4563345239240013, 1.4563345239240013, 1.4563345239240013]
Evaluating the Boltzmann equation (part 2):
chiI = np.array([6.1131554, 11.871719, 50.91316, 67.2732, 84.34])  # in eV
k = 8.617333262145179e-05
T = 10000.

def boltz_E(chiI, T, I, i):
    Z_1 = partfunc_E(chiI, T)
    ratio = np.exp(-i/(k*T)) / Z_1
    return ratio[I-1]

print(Ca_ion_energies)
print("i Fraction in level i for I=1 (neutral)")
print("- -------------------------------------")
for n in range(0, 10):
    print(n, boltz_E(chiI, 10000, 1, n))
Output:
[ 6.1131554 11.871719 50.91316 67.2732 84.34 ]
i Fraction in level i for I=1 (neutral)
- -------------------------------------
0 0.6868591389658425
1 0.21522358567610525
2 0.06743914320048579
3 0.021131689732463026
4 0.006621500359539954
5 0.002074811222693332
6 0.0006501308428703751
7 0.0002037149733085943
8 6.383298193775377e-05
9 2.0001718660577703e-05
Question I need help with (and my code so far):
Evaluating the Saha equation (part 3):
The instructions for this section are as follows:
The simplest way to get this ratio is to set 𝑁_𝐼=1 (i.e. the neutral atom) to some value (e.g. unity), evaluate the next ionisation-stage populations successively from the Saha equation in a for loop, and at the end divide them by the sum of all the 𝑁 on the same scale. You will find the numpy np.sum function useful to get the total over all stages. We want temperature T to be 5000K and electron pressure Pe to be 100.0 N/m^2.
FYI: I is the ionisation stage, Z_1 is the partition function from part 1, Z_I is the partition function for stage I+1, Pe is the electron pressure, chiI are the ionisation energies (for Calcium in my code), T is temperature and the function that "fraction" is set equal to is the Saha equation.
It should start something like:
def saha_E(chiI, T, Pe, I):
    """Compute Saha population fraction N_I/N.
    Input: ionisation energies, temperature, electron pressure, ion stage."""
    # Compute the partition functions.
    # Loop over each ionisation stage that you have an energy for,
    # computing the fraction via the Saha equation. Note that the
    # first stage should be set to 1.
    # Divide each stage by the total.
    # Return the fraction of the requested stage.
My code attempt:
k = 8.617333262145179e-05
T = 10000.
g = 1.0
Ca_ion_energies = np.array([6.1131554, 11.871719, 50.91316, 67.2732, 84.34])
N_I = 1
h = 6.626e-34
m = 9.11e-31
fractions = []
fraction_sum = []

def saha_E(chiI, T, Pe, I):
    Z_1 = partfunc_E(chiI, T)
    Z_I = partfunc_E(chiI+1, T)
    for I in Ca_ion_energies:
        fraction = (N_I*(Z_I/Z_1)*((2*k*T)/((h**3)*Pe))*((2*np.pi*m*k*T)**(3/2))*np.exp(-I/(k*T)))
        fractions.append(fraction)
    fraction_sum.append(np.sum(fractions))
    for i in fractions:
        i/fraction_sum
    return fraction

print("For ionisation energies (in eV) of:", chiI)
print()
print("I Fraction in stage I")
print("- -------------------")
for I in range(0, 6):
    print(I, saha_E(chiI, 5000, 100.0, I))
I am instructed also that the output should be something similar to:
For ionisation energies (in eV) of: [ 6.11 11.87 50.91 67.27 84.34]
I Fraction in stage I
- -------------------
1 0.999998720736
2 1.27926351211e-06
3 7.29993420039e-52
4 1.3474665329e-113
5 1.54848994685e-192
Firstly, I don't think my code is correct, but it is the best I can do, which is why I need some help. This code is also giving me the following error:
TypeError: unsupported operand type(s) for /: 'list' and 'list'
If my code is totally wrong please tell me as I have spent so much time trying to figure this out already.
Edit
This question is still not completely answered, please keep commenting!
If I understood your problem correctly, my approach is to calculate the "fractions" and "fraction sums" in a single loop over the various energies, and to normalize only once we are outside the loop.
Also, be careful with the scope of your code. I pushed some variables you declared outside of the function inside of it, because there is no reason to keep them alive outside of the function's scope.
Be careful also not to use the same variable twice. Your function takes an I argument but then has an I variable in a for loop.
As said in the chat, you want to write docstrings and comments so that you know where you are going even before touching any code. Here is a base to complete:
import numpy as np

# Constants.
k = 8.617333262145179e-05
g = 1.0
h = 6.626e-34
m = 9.11e-31
Ca_ion_energies = np.array([6.1131554, 11.871719, 50.91316, 67.2732, 84.34])  # in eV.

# Partition function.
def partfunc_E(chiI, T):
    """This function returns the partition of blablabla.

    args:
    ------
    :chiI: (array or list) the energy levels of a chosen ion.
    :T: (float) the temperature at which kT will be calculated."""
    Ca_partition_values = []
    for energy_level in chiI:  # For each energy level.
        elem = 0
        for i in np.arange(energy_level):  # From 0 to current energy level.
            elem += g*np.exp(-(i/(k*T)))
        Ca_partition_values.append(elem)
    return np.array(Ca_partition_values)  # Conversion to numpy array to support operations later.

print(partfunc_E(Ca_ion_energies, T=10000))

# Boltzmann equation.
def boltz_E(chiI, T, I, i):
    Z_1 = partfunc_E(chiI, T)
    ratio = np.exp(-i/(k*T)) / Z_1
    return ratio[I-1]

print(Ca_ion_energies)
print("i Fraction in level i for I=1 (neutral)")
print("- -------------------------------------")
for n in range(0, 10):
    print(n, boltz_E(Ca_ion_energies, T=10000, I=1, i=n))

# Saha equation.
def saha_E(chiI, T, Pe, i):
    p = partfunc_E(chiI, T)
    Z_ratios = np.array([p[n]/p[0] for n in range(len(chiI))])
    fractions = []
    fractions_sum = []
    for n, I in enumerate(chiI):
        fraction = Z_ratios[n]*((2*k*T)/((h**3)*Pe))*((2*np.pi*m*k*T)**(3/2))*np.exp(-I/(k*T))
        fractions.append(fraction)
        fractions_sum.append(np.sum(fractions))
    # Let's normalize the array before returning it.
    fractions = np.divide(fractions, fractions_sum)
    return fractions[i]

print("For ionisation energies (in eV) of:", Ca_ion_energies)
print()
print("I Fraction in stage n")
print("- -------------------")
for n in range(0, 4):
    print(n, saha_E(Ca_ion_energies, T=5000, Pe=100.0, i=n))

Optimization of CodeWars Python 3.6 code: Integers: Recreation One

I need help optimizing my python 3.6 code for the CodeWars Integers: Recreation One Kata.
We are given a range of numbers and we have to return the number and the sum of the divisors squared that is a square itself.
"Divisors of 42 are : 1, 2, 3, 6, 7, 14, 21, 42. These divisors squared are: 1, 4, 9, 36, 49, 196, 441, 1764. The sum of the squared divisors is 2500 which is 50 * 50, a square!
Given two integers m, n (1 <= m <= n) we want to find all integers between m and n whose sum of squared divisors is itself a square. 42 is such a number."
My code works for individual tests, but it times out when submitting:
def list_squared(m, n):
    sqsq = []
    for i in range(m, n):
        divisors = [j**2 for j in range(1, i+1) if i % j == 0]
        sq_divs = sum(divisors)
        sq = sq_divs ** (1/2)
        if int(sq) ** 2 == sq_divs:
            sqsq.append([i, sq_divs])
    return sqsq
You can reduce the complexity of the loop in the list comprehension from O(N) to O(sqrt(N)) by setting the range limit to sqrt(num)+1 instead of num.
By looping from 1 to sqrt(num)+1, we can conclude that if i (the current item in the loop) is a factor of num, then num divided by i must be another one.
E.g.: 2 is a factor of 10, and so is 5 (10/2).
The following code passes all the tests:
import math

def list_squared(m, n):
    result = []
    for num in range(m, n + 1):
        divisors = set()
        for i in range(1, int(math.sqrt(num)+1)):
            if num % i == 0:
                divisors.add(i**2)
                divisors.add(int(num/i)**2)
        total = sum(divisors)
        sr = math.sqrt(total)
        if sr - math.floor(sr) == 0:
            result.append([num, total])
    return result
It's more of a math issue. The two largest divisors of i are i itself and, at most, i/2, so you can speed up the code twice just by using i // 2 + 1 as the range stop instead of i + 1. Just don't forget to add i ** 2 to sq_divs afterwards.
You may also get some tiny performance improvements by dropping the sq variable and the sq_divs ** (1/2) computation.
BTW, you should use n+1 as the stop in the first range.
def list_squared(m, n):
    sqsq = []
    for i in range(m, n+1):
        divisors = [j * j for j in range(1, i // 2 + 1  # speed up twice
                                         ) if i % j == 0]
        sq_divs = sum(divisors)
        sq_divs += i * i  # add i as divisor
        if ((sq_divs) ** 0.5) % 1 == 0:  # tiny speed up here
            sqsq.append([i, sq_divs])
    return sqsq
UPD: I've tried the Kata and it still times out, so we need even more math! If i can be divided by j then it can also be divided by i/j, so we can use int(math.sqrt(i)) + 1 as the range stop: if i % j == 0, append j * j to the divisors array, AND if i / j != j, also append (i / j) ** 2.
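Here is a sketch of that variant (my own rendering of the description above; I have not verified it against the Kata's time limit):

import math

def list_squared(m, n):
    sqsq = []
    for i in range(m, n + 1):
        divisors = []
        for j in range(1, int(math.sqrt(i)) + 1):
            if i % j == 0:
                divisors.append(j * j)    # j is a divisor
                if i // j != j:           # add the complement, unless j is the exact square root
                    divisors.append((i // j) ** 2)
        sq_divs = sum(divisors)
        if (sq_divs ** 0.5) % 1 == 0:
            sqsq.append([i, sq_divs])
    return sqsq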

simpson integration on python

I am trying to integrate numerically using the Simpson integration rule for f(x) = 2x from 0 to 1, but I keep getting a large error. The desired output is 1, but the output from Python is 1.334. Can someone help me find a solution to this problem?
Thank you.
import numpy as np

def f(x):
    return 2*x

def simpson(f, a, b, n):
    x = np.linspace(a, b, n)
    dx = (b-a)/n
    for i in np.arange(1, n):
        if i % 2 != 0:
            y = 4*f(x)
        elif i % 2 == 0:
            y = 2*f(x)
    return (f(a)+sum(y)+f(x)[-1])*dx/3

a = 0
b = 1
n = 1000
ans = simpson(f, a, b, n)
print(ans)
There are several things wrong here. x is an array, so every time you call f(x) you evaluate the function over the whole array. Since n is even and n-1 is odd, the y left over after the last loop iteration is 4*f(x), and the sum is computed from that.
Also, n is the number of segments; the number of points is n+1. A correct implementation is
def simpson(f, a, b, n):
    x = np.linspace(a, b, n+1)
    y = f(x)
    dx = x[1] - x[0]
    return (y[0] + 4*sum(y[1::2]) + 2*sum(y[2:-1:2]) + y[-1])*dx/3

simpson(lambda x: 2*x, 0, 1, 1000)
which then correctly returns 1.000. You might want to add a test for whether n is even, and increase it by one if it is not; a sketch follows.
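A guarded variant (my own wrapper of the corrected implementation above; the name simpson_even is hypothetical):

import numpy as np

def simpson_even(f, a, b, n):
    # Simpson's rule needs an even number of segments.
    if n % 2 != 0:
        n += 1
    x = np.linspace(a, b, n+1)
    y = f(x)
    dx = x[1] - x[0]
    return (y[0] + 4*sum(y[1::2]) + 2*sum(y[2:-1:2]) + y[-1])*dx/3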
If you really want to keep the loop, you need to actually accumulate the sum inside the loop.
def simpson(f, a, b, n):
    dx = (b-a)/n
    res = 0
    for i in range(1, n):
        res += f(a+i*dx)*(2 if i % 2 == 0 else 4)
    return (f(a) + f(b) + res)*dx/3

simpson(lambda x: 2*x, 0, 1, 1000)
But loops are generally slower than vectorized operations, so if you use numpy, use vectorized operations. Or just use scipy.integrate.simps directly.
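A usage sketch with SciPy (in current SciPy releases the function is scipy.integrate.simpson; older releases expose it as simps):

import numpy as np
from scipy.integrate import simpson

x = np.linspace(0, 1, 1001)  # an odd number of samples = an even number of segments
print(simpson(2 * x, x=x))   # -> 1.0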

Speed up bitstring/bit operations in Python?

I wrote a prime number generator using Sieve of Eratosthenes and Python 3.1. The code runs correctly and gracefully at 0.32 seconds on ideone.com to generate prime numbers up to 1,000,000.
# from bitstring import BitString

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    flags = [False, False] + [True] * (limit - 2)
    # flags = BitString(limit)
    # Step through all the odd numbers
    for i in range(3, limit, 2):
        if flags[i] is False:
        # if flags[i] is True:
            continue
        yield i
        # Exclude further multiples of the current prime number
        if i <= sub_limit:
            for j in range(i*3, limit, i<<1):
                flags[j] = False
                # flags[j] = True
The problem is, I run out of memory when I try to generate numbers up to 1,000,000,000.
flags = [False, False] + [True] * (limit - 2)
MemoryError
As you can imagine, allocating 1 billion boolean values (4 or 8 bytes each in Python; see the comments) is really not feasible, so I looked into bitstring. I figured using 1 bit for each flag would be much more memory-efficient. However, the program's performance dropped drastically: 24 seconds runtime for primes up to 1,000,000. This is probably due to the internal implementation of bitstring.
You can comment/uncomment the three marked lines in the snippet above to see what I changed to use BitString.
My question is, is there a way to speed up my program, with or without bitstring?
Edit: Please test the code yourself before posting. I can't accept answers that run slower than my existing code, naturally.
Edit again:
I've compiled a list of benchmarks on my machine.
There are a couple of small optimizations for your version. By reversing the roles of True and False, you can change "if flags[i] is False:" to "if flags[i]:". And the starting value for the second range statement can be i*i instead of i*3. Your original version takes 0.166 seconds on my system. With those changes, the version below takes 0.156 seconds on my system.
def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    flags = [True, True] + [False] * (limit - 2)
    # Step through all the odd numbers
    for i in range(3, limit, 2):
        if flags[i]:
            continue
        yield i
        # Exclude further multiples of the current prime number
        if i <= sub_limit:
            for j in range(i*i, limit, i<<1):
                flags[j] = True
This doesn't help your memory issue, though.
Moving into the world of C extensions, I used the development version of gmpy. (Disclaimer: I'm one of the maintainers.) The development version is called gmpy2 and supports mutable integers called xmpz. Using gmpy2 and the following code, I have a running time of 0.140 seconds. Running time for a limit of 1,000,000,000 is 158 seconds.
import gmpy2

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    # Actual number is 2*bit_position + 1.
    oddnums = gmpy2.xmpz(1)
    current = 0
    while True:
        current += 1
        current = oddnums.bit_scan0(current)
        prime = 2 * current + 1
        if prime > limit:
            break
        yield prime
        # Exclude further multiples of the current prime number
        if prime <= sub_limit:
            for j in range(2*current*(current+1), limit>>1, prime):
                oddnums.bit_set(j)
Pushing optimizations, and sacrificing clarity, I get running times of 0.107 and 123 seconds with the following code:
import gmpy2

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    # Actual number is 2*bit_position + 1.
    oddnums = gmpy2.xmpz(1)
    f_set = oddnums.bit_set
    f_scan0 = oddnums.bit_scan0
    current = 0
    while True:
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1
        if prime > limit:
            break
        yield prime
        # Exclude further multiples of the current prime number
        if prime <= sub_limit:
            list(map(f_set, range(2*current*(current+1), limit>>1, prime)))
Edit: Based on this exercise, I modified gmpy2 to accept xmpz.bit_set(iterator). Using the following code, the run time for all primes less 1,000,000,000 is 56 seconds for Python 2.7 and 74 seconds for Python 3.2. (As noted in the comments, xrange is faster than range.)
import gmpy2

try:
    range = xrange
except NameError:
    pass

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    oddnums = gmpy2.xmpz(1)
    f_scan0 = oddnums.bit_scan0
    current = 0
    while True:
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1
        if prime > limit:
            break
        yield prime
        if prime <= sub_limit:
            oddnums.bit_set(iter(range(2*current*(current+1), limit>>1, prime)))
Edit #2: One more try! I modified gmpy2 to accept xmpz.bit_set(slice). Using the following code, the run time for all primes less 1,000,000,000 is about 40 seconds for both Python 2.7 and Python 3.2.
from __future__ import print_function
import time
import gmpy2

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    flags = gmpy2.xmpz(1)
    # pre-allocate the total length
    flags.bit_set((limit>>1)+1)
    f_scan0 = flags.bit_scan0
    current = 0
    while True:
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1
        if prime > limit:
            break
        yield prime
        if prime <= sub_limit:
            flags.bit_set(slice(2*current*(current+1), limit>>1, prime))

start = time.time()
result = list(prime_numbers(1000000000))
print(time.time() - start)
Edit #3: I've updated gmpy2 to properly support slicing at the bit level of an xmpz. No change in performance, but a much nicer API. I have done a little tweaking and I've got the time down to about 37 seconds. (See Edit #4 for changes in gmpy2 2.0.0b1.)
from __future__ import print_function
import time
import gmpy2

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    sub_limit = int(limit**0.5)
    flags = gmpy2.xmpz(1)
    flags[(limit>>1)+1] = True
    f_scan0 = flags.bit_scan0
    current = 0
    prime = 2
    while prime <= sub_limit:
        yield prime
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1
        flags[2*current*(current+1):limit>>1:prime] = True
    while prime <= limit:
        yield prime
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1

start = time.time()
result = list(prime_numbers(1000000000))
print(time.time() - start)
Edit #4: I made some changes in gmpy2 2.0.0b1 that break the previous example. gmpy2 no longer treats True as a special value that provides an infinite source of 1-bits. -1 should be used instead.
from __future__ import print_function
import time
import gmpy2

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    sub_limit = int(limit**0.5)
    flags = gmpy2.xmpz(1)
    flags[(limit>>1)+1] = 1
    f_scan0 = flags.bit_scan0
    current = 0
    prime = 2
    while prime <= sub_limit:
        yield prime
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1
        flags[2*current*(current+1):limit>>1:prime] = -1
    while prime <= limit:
        yield prime
        current += 1
        current = f_scan0(current)
        prime = 2 * current + 1

start = time.time()
result = list(prime_numbers(1000000000))
print(time.time() - start)
Edit #5: I've made some enhancements to gmpy2 2.0.0b2. You can now iterate over all the bits that are either set or clear. Running time has improved by ~30%.
from __future__ import print_function
import time
import gmpy2

def sieve(limit=1000000):
    '''Returns a generator that yields the prime numbers up to limit.'''
    # Increment by 1 to account for the fact that slices do not include
    # the last index value but we do want to include the last value for
    # calculating a list of primes.
    sieve_limit = gmpy2.isqrt(limit) + 1
    limit += 1
    # Mark bit positions 0 and 1 as not prime.
    bitmap = gmpy2.xmpz(3)
    # Process 2 separately. This allows us to use p+p for the step size
    # when sieving the remaining primes.
    bitmap[4 : limit : 2] = -1
    # Sieve the remaining primes.
    for p in bitmap.iter_clear(3, sieve_limit):
        bitmap[p*p : limit : p+p] = -1
    return bitmap.iter_clear(2, limit)

if __name__ == "__main__":
    start = time.time()
    result = list(sieve(1000000000))
    print(time.time() - start)
    print(len(result))
OK, so this is my second answer, but as speed is of the essence I thought that I had to mention the bitarray module - even though it's bitstring's nemesis :). It's ideally suited to this case as not only is it a C extension (and so faster than pure Python has a hope of being), but it also supports slice assignments.
I haven't even tried to optimise this, I just rewrote the bitstring version. On my machine I get 0.16 seconds for primes under a million.
For a billion, it runs perfectly well and completes in 2 minutes 31 seconds.
import bitarray

def prime_bitarray(limit=1000000):
    yield 2
    flags = bitarray.bitarray(limit)
    flags.setall(False)
    sub_limit = int(limit**0.5)
    for i in range(3, limit, 2):
        if not flags[i]:
            yield i
            if i <= sub_limit:
                flags[3*i:limit:i*2] = True
Okay, here's a (near complete) comprehensive benchmarking I've done tonight to see which code runs the fastest. Hopefully someone will find this list useful. I omitted anything that takes more than 30 seconds to complete on my machine.
I would like to thank everyone who contributed. I've gained a lot of insight from your efforts, and I hope you have too.
My machine: AMD ZM-86, 2.40 Ghz Dual-Core, with 4GB of RAM. This is a HP Touchsmart Tx2 laptop. Note that while I may have linked to a pastebin, I benchmarked all of the following on my own machine.
I will add the gmpy2 benchmark once I am able to build it.
All of the benchmarks are tested in Python 2.6 x86
Returning a list of prime numbers up to 1,000,000 (using Python generators):

Sebastian's numpy generator version (updated) - 121 ms #
Mark's Sieve + Wheel - 154 ms
Robert's version with slicing - 159 ms
My improved version with slicing - 205 ms
Numpy generator with enumerate - 249 ms #
Mark's Basic Sieve - 317 ms
casevh's improvement on my original solution - 343 ms
My modified numpy generator solution - 407 ms
My original method in the question - 409 ms
Bitarray Solution - 414 ms #
Pure Python with bytearray - 1394 ms #
Scott's BitString solution - 6659 ms #

'#' means this method is capable of generating up to n < 1,000,000,000 on my machine setup.

In addition, if you don't need the generator and just want the whole list at once:

numpy solution from RosettaCode - 32 ms #
(The numpy solution is also capable of generating up to 1 billion, which took 61.6259 seconds. I suspect the memory was swapped once, hence the double time.)
Related question: Fastest way to list all primes below N in python.
Hi, I am also looking for code in Python to generate primes up to 10^9 as fast as possible, which is difficult because of the memory problem. Up to now I have come up with this to generate primes up to 10^6 and 10^7 (clocking 0.171s and 1.764s respectively on my old machine), which seems to be slightly faster (at least on my computer) than "My improved version with slicing" from the previous post, probably because I use n//i-i+1 (see below) instead of the analogous len(flags[i2::i<<1]) in the other code. Please correct me if I am wrong. Any suggestions for improvement are very welcome.
def primes(n):
    """ Returns a list of primes < n """
    sieve = [True] * n
    for i in xrange(3, int(n**0.5)+1, 2):
        if sieve[i]:
            sieve[i*i::2*i] = [False]*((n-i*i-1)/(2*i)+1)
    return [2] + [i for i in xrange(3, n, 2) if sieve[i]]
In one of his codes Xavier uses flags[i2::i<<1] and len(flags[i2::i<<1]). I computed the size explicitly: between i*i and n there are (n-i*i-1)//(2*i) + 1 multiples of 2*i (the +1 is because we also want to count i*i itself), so len(sieve[i*i::2*i]) equals (n-i*i-1)//(2*i) + 1. This makes the code faster. A basic generator based on the code above would be:
def primesgen(n):
    """ Generates all primes <= n """
    sieve = [True] * n
    yield 2
    for i in xrange(3, int(n**0.5)+1, 2):
        if sieve[i]:
            yield i
            sieve[i*i::2*i] = [False]*((n-i*i-1)/(2*i)+1)
    for i in xrange(i+2, n, 2):
        if sieve[i]: yield i
With a bit of modification one can write a slightly slower version of the code above that starts with a sieve half of the size, sieve = [True] * (n//2), and works for the same n. I am not sure how that will reflect on the memory issue. As an example implementation, here is a modified version of the numpy Rosetta Code solution (maybe faster), starting with a sieve half of the size.
import numpy

def primesfrom3to(n):
    """ Returns an array of primes, 3 <= p < n """
    sieve = numpy.ones(n/2, dtype=numpy.bool)
    for i in xrange(3, int(n**0.5)+1, 2):
        if sieve[i/2]:
            sieve[i*i/2::i] = False
    return 2*numpy.nonzero(sieve)[0][1::]+1
A faster and more memory-efficient generator would be:
import numpy

def primesgen1(n):
    """ Input n>=6, Generates all primes < n """
    sieve1 = numpy.ones(n/6+1, dtype=numpy.bool)
    sieve5 = numpy.ones(n/6, dtype=numpy.bool)
    sieve1[0] = False
    yield 2; yield 3;
    for i in xrange(int(n**0.5)/6+1):
        if sieve1[i]:
            k = 6*i+1; yield k;
            sieve1[ ((k*k)/6) ::k] = False
            sieve5[(k*k+4*k)/6::k] = False
        if sieve5[i]:
            k = 6*i+5; yield k;
            sieve1[ ((k*k)/6) ::k] = False
            sieve5[(k*k+2*k)/6::k] = False
    for i in xrange(i+1, n/6):
        if sieve1[i]: yield 6*i+1
        if sieve5[i]: yield 6*i+5
    if n % 6 > 1:
        if sieve1[i+1]: yield 6*(i+1)+1
or with a bit more code:
import numpy

def primesgen(n):
    """ Input n>=30, Generates all primes < n """
    size = n/30 + 1
    sieve01 = numpy.ones(size, dtype=numpy.bool)
    sieve07 = numpy.ones(size, dtype=numpy.bool)
    sieve11 = numpy.ones(size, dtype=numpy.bool)
    sieve13 = numpy.ones(size, dtype=numpy.bool)
    sieve17 = numpy.ones(size, dtype=numpy.bool)
    sieve19 = numpy.ones(size, dtype=numpy.bool)
    sieve23 = numpy.ones(size, dtype=numpy.bool)
    sieve29 = numpy.ones(size, dtype=numpy.bool)
    sieve01[0] = False
    yield 2; yield 3; yield 5;
    for i in xrange(int(n**0.5)/30+1):
        if sieve01[i]:
            k = 30*i+1; yield k;
            sieve01[ (k*k)/30::k] = False
            sieve07[(k*k+ 6*k)/30::k] = False
            sieve11[(k*k+10*k)/30::k] = False
            sieve13[(k*k+12*k)/30::k] = False
            sieve17[(k*k+16*k)/30::k] = False
            sieve19[(k*k+18*k)/30::k] = False
            sieve23[(k*k+22*k)/30::k] = False
            sieve29[(k*k+28*k)/30::k] = False
        if sieve07[i]:
            k = 30*i+7; yield k;
            sieve01[(k*k+ 6*k)/30::k] = False
            sieve07[(k*k+24*k)/30::k] = False
            sieve11[(k*k+16*k)/30::k] = False
            sieve13[(k*k+12*k)/30::k] = False
            sieve17[(k*k+ 4*k)/30::k] = False
            sieve19[ (k*k)/30::k] = False
            sieve23[(k*k+22*k)/30::k] = False
            sieve29[(k*k+10*k)/30::k] = False
        if sieve11[i]:
            k = 30*i+11; yield k;
            sieve01[ (k*k)/30::k] = False
            sieve07[(k*k+ 6*k)/30::k] = False
            sieve11[(k*k+20*k)/30::k] = False
            sieve13[(k*k+12*k)/30::k] = False
            sieve17[(k*k+26*k)/30::k] = False
            sieve19[(k*k+18*k)/30::k] = False
            sieve23[(k*k+ 2*k)/30::k] = False
            sieve29[(k*k+ 8*k)/30::k] = False
        if sieve13[i]:
            k = 30*i+13; yield k;
            sieve01[(k*k+24*k)/30::k] = False
            sieve07[(k*k+ 6*k)/30::k] = False
            sieve11[(k*k+ 4*k)/30::k] = False
            sieve13[(k*k+18*k)/30::k] = False
            sieve17[(k*k+16*k)/30::k] = False
            sieve19[ (k*k)/30::k] = False
            sieve23[(k*k+28*k)/30::k] = False
            sieve29[(k*k+10*k)/30::k] = False
        if sieve17[i]:
            k = 30*i+17; yield k;
            sieve01[(k*k+ 6*k)/30::k] = False
            sieve07[(k*k+24*k)/30::k] = False
            sieve11[(k*k+26*k)/30::k] = False
            sieve13[(k*k+12*k)/30::k] = False
            sieve17[(k*k+14*k)/30::k] = False
            sieve19[ (k*k)/30::k] = False
            sieve23[(k*k+ 2*k)/30::k] = False
            sieve29[(k*k+20*k)/30::k] = False
        if sieve19[i]:
            k = 30*i+19; yield k;
            sieve01[ (k*k)/30::k] = False
            sieve07[(k*k+24*k)/30::k] = False
            sieve11[(k*k+10*k)/30::k] = False
            sieve13[(k*k+18*k)/30::k] = False
            sieve17[(k*k+ 4*k)/30::k] = False
            sieve19[(k*k+12*k)/30::k] = False
            sieve23[(k*k+28*k)/30::k] = False
            sieve29[(k*k+22*k)/30::k] = False
        if sieve23[i]:
            k = 30*i+23; yield k;
            sieve01[(k*k+24*k)/30::k] = False
            sieve07[(k*k+ 6*k)/30::k] = False
            sieve11[(k*k+14*k)/30::k] = False
            sieve13[(k*k+18*k)/30::k] = False
            sieve17[(k*k+26*k)/30::k] = False
            sieve19[ (k*k)/30::k] = False
            sieve23[(k*k+ 8*k)/30::k] = False
            sieve29[(k*k+20*k)/30::k] = False
        if sieve29[i]:
            k = 30*i+29; yield k;
            sieve01[ (k*k)/30::k] = False
            sieve07[(k*k+24*k)/30::k] = False
            sieve11[(k*k+20*k)/30::k] = False
            sieve13[(k*k+18*k)/30::k] = False
            sieve17[(k*k+14*k)/30::k] = False
            sieve19[(k*k+12*k)/30::k] = False
            sieve23[(k*k+ 8*k)/30::k] = False
            sieve29[(k*k+ 2*k)/30::k] = False
    for i in xrange(i+1, n/30):
        if sieve01[i]: yield 30*i+1
        if sieve07[i]: yield 30*i+7
        if sieve11[i]: yield 30*i+11
        if sieve13[i]: yield 30*i+13
        if sieve17[i]: yield 30*i+17
        if sieve19[i]: yield 30*i+19
        if sieve23[i]: yield 30*i+23
        if sieve29[i]: yield 30*i+29
    i += 1
    mod30 = n % 30
    if mod30 > 1:
        if sieve01[i]: yield 30*i+1
    if mod30 > 7:
        if sieve07[i]: yield 30*i+7
    if mod30 > 11:
        if sieve11[i]: yield 30*i+11
    if mod30 > 13:
        if sieve13[i]: yield 30*i+13
    if mod30 > 17:
        if sieve17[i]: yield 30*i+17
    if mod30 > 19:
        if sieve19[i]: yield 30*i+19
    if mod30 > 23:
        if sieve23[i]: yield 30*i+23
PS: If you have any suggestions about how to speed up this code, anything from changing the order of operations to pre-computing things, please comment.
Here's a version that I wrote a while back; it might be interesting to compare with yours for speed. It doesn't do anything about the space problems, though (in fact, they're probably worse than with your version).
from math import sqrt

def basicSieve(n):
    """Given a positive integer n, generate the primes < n."""
    s = [1]*n
    for p in xrange(2, 1+int(sqrt(n-1))):
        if s[p]:
            a = p*p
            s[a::p] = [0] * -((a-n)//p)
    for p in xrange(2, n):
        if s[p]:
            yield p
I have faster versions, using a wheel, but they're much more complicated. Original source is here.
Okay, here's the version using a wheel. wheelSieve is the main entry point.
from math import sqrt
from bisect import bisect_left

def basicSieve(n):
    """Given a positive integer n, generate the primes < n."""
    s = [1]*n
    for p in xrange(2, 1+int(sqrt(n-1))):
        if s[p]:
            a = p*p
            s[a::p] = [0] * -((a-n)//p)
    for p in xrange(2, n):
        if s[p]:
            yield p

class Wheel(object):
    """Class representing a wheel.

    Attributes:
      primelimit -> wheel covers primes < primelimit.
        For example, given a primelimit of 6
        the wheel primes are 2, 3, and 5.
      primes -> list of primes less than primelimit
      size -> product of the primes in primes; the modulus of the wheel
      units -> list of units modulo size
      phi -> number of units
    """
    def __init__(self, primelimit):
        self.primelimit = primelimit
        self.primes = list(basicSieve(primelimit))

        # compute the size of the wheel
        size = 1
        for p in self.primes:
            size *= p
        self.size = size

        # find the units by sieving
        units = [1] * self.size
        for p in self.primes:
            units[::p] = [0]*(self.size//p)
        self.units = [i for i in xrange(self.size) if units[i]]

        # number of units
        self.phi = len(self.units)

    def to_index(self, n):
        """Compute alpha(n), where alpha is an order preserving map
        from the set of units modulo self.size to the nonnegative integers.

        If n is not a unit, the index of the first unit greater than n
        is given."""
        return bisect_left(self.units, n % self.size) + n // self.size * self.phi

    def from_index(self, i):
        """Inverse of to_index."""
        return self.units[i % self.phi] + i // self.phi * self.size

def wheelSieveInner(n, wheel):
    """Given a positive integer n and a wheel, find the wheel indices of
    all primes that are less than n, and that are units modulo the
    wheel modulus.
    """
    # renaming to avoid the overhead of attribute lookups
    U = wheel.units
    wS = wheel.size
    # inverse of unit map
    UI = dict((u, i) for i, u in enumerate(U))
    nU = len(wheel.units)

    sqroot = 1+int(sqrt(n-1))  # ceiling of square root of n

    # corresponding index (index of next unit up)
    sqrti = bisect_left(U, sqroot % wS) + sqroot//wS*nU
    upper = bisect_left(U, n % wS) + n//wS*nU
    ind2 = bisect_left(U, 2 % wS) + 2//wS*nU

    s = [1]*upper
    for i in xrange(ind2, sqrti):
        if s[i]:
            q = i//nU
            u = U[i%nU]
            p = q*wS+u
            u2 = u*u
            aq, au = (p+u)*q+u2//wS, u2%wS
            wp = p * nU
            for v in U:
                # eliminate entries corresponding to integers
                # congruent to p*v modulo p*wS
                uvr = u*v%wS
                m = aq + (au > uvr)
                bot = (m + (q*v + u*v//wS - m) % p) * nU + UI[uvr]
                s[bot::wp] = [0]*-((bot-upper)//wp)
    return s

def wheelSieve(n, wheel=Wheel(10)):
    """Given a positive integer n, generate the list of primes <= n."""
    n += 1
    wS = wheel.size
    U = wheel.units
    nU = len(wheel.units)
    s = wheelSieveInner(n, wheel)
    return (wheel.primes[:bisect_left(wheel.primes, n)] +
            [p//nU*wS + U[p%nU] for p in xrange(bisect_left(U, 2 % wS)
             + 2//wS*nU, len(s)) if s[p]])
As to what a wheel is: well, you know that (apart from 2), all primes are odd, so most sieves miss out all the even numbers. Similarly, you can go a bit further and notice that all primes (except 2 and 3) are congruent to 1 or 5 modulo 6 (== 2 * 3), so you can get away with only storing entries for those numbers in your sieve. The next step up is to note that all primes (except 2, 3 and 5) are congruent to one of 1, 7, 11, 13, 17, 19, 23, 29 (modulo 30) (here 30 == 2*3*5), and so on.
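A tiny sketch that computes those units (the residues coprime to the wheel modulus), just to make the pattern concrete:

from math import gcd

def wheel_units(primes):
    # The modulus is the product of the wheel primes; the units are the
    # residues coprime to it, the only places where larger primes can live.
    modulus = 1
    for p in primes:
        modulus *= p
    return modulus, [r for r in range(modulus) if gcd(r, modulus) == 1]

print(wheel_units([2, 3]))     # (6, [1, 5])
print(wheel_units([2, 3, 5]))  # (30, [1, 7, 11, 13, 17, 19, 23, 29])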
One speed improvement you can make using bitstring is to use the 'set' method when you exclude multiples of the current number.
So the vital section becomes
for i in range(3, limit, 2):
    if flags[i]:
        yield i
        if i <= sub_limit:
            flags.set(1, range(i*3, limit, i*2))
On my machine this runs about 3 times faster than the original.
My other thought was to use the bitstring to represent only the odd numbers. You could then find the unset bits using flags.findall([0]) which returns a generator. Not sure if that would be much faster (finding bit patterns isn't terribly easy to do efficiently).
[Full disclosure: I wrote the bitstring module, so I've got some pride at stake here!]
As a comparison I've also taken the guts out of the bitstring method so that it's doing it in the same way, but without range checking etc. I think this gives a reasonable lower limit for pure Python that works for a billion elements (without changing the algorithm, which I think is cheating!)
def prime_pure(limit=1000000):
    yield 2
    flags = bytearray((limit + 7) // 8)
    sub_limit = int(limit**0.5)
    for i in xrange(3, limit, 2):
        byte, bit = divmod(i, 8)
        if not flags[byte] & (128 >> bit):
            yield i
            if i <= sub_limit:
                for j in xrange(i*3, limit, i*2):
                    byte, bit = divmod(j, 8)
                    flags[byte] |= (128 >> bit)
On my machine this runs in about 0.62 seconds for a million elements, which means it's about a quarter of the speed of the bitarray answer. I think that's quite reasonable for pure Python.
Python's integer types can be of arbitrary size, so you shouldn't need a clever bitstring library to represent a set of bits, just a single number.
Here's the code. It handles a limit of 1,000,000 with ease, and can even handle 10,000,000 without complaining too much:
def multiples_of(n, step, limit):
    bits = 1 << n
    old_bits = bits
    max = 1 << limit
    while old_bits < max:
        old_bits = bits
        bits += bits << step
        step *= 2
    return old_bits

def prime_numbers(limit=10000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5)
    flags = ((1 << (limit - 2)) - 1) << 2
    # Step through all the odd numbers
    for i in xrange(3, limit, 2):
        if not (flags & (1 << i)):
            continue
        yield i
        # Exclude further multiples of the current prime number
        if i <= sub_limit:
            flags &= ~multiples_of(i * 3, i << 1, limit)
The multiples_of function avoids the cost of setting every single multiple individually. It sets the initial bit, shifts it by the initial step (i << 1) and adds the result to itself. It then doubles the step, so that the next shift copies both bits, and so on until it reaches the limit. This sets all the multiples of a number up to the limit in O(log(limit)) time.
One option you may want to look at is just compiling a C/C++ module so you have direct access to the bit-twiddling features. AFAIK you could write something of that nature and only return the results on completion of the calculations performed in C/C++. Now that I'm typing this you may also look at numpy which does store arrays as native C for speed. I don't know if that will be any faster than the bitstring module, though!
Another really stupid option, but one that can help if you really need a large list of prime numbers available very fast. Say, if you need them as a tool to answer Project Euler problems (if the problem is just optimizing your code as a mind game, it's irrelevant).
Use any slow solution to generate the list and save it to a Python source file, say primes.py, that would just contain:
primes = [ a list of a million prime numbers here ]
Now to use them you just have to do import primes, and you get them with a minimal memory footprint at the speed of disk IO. Complexity is the number of primes :-)
Even if you used a poorly optimized solution to generate this list, it will only be done once and it does not matter much.
You could probably make it even faster using pickle/unpickle, but you get the idea...
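A minimal sketch of the pickle variant (the file name is arbitrary, and any of the slow generators from this page, here the prime_numbers() name, can produce the list):

import pickle

# Generate once with any slow solution and save:
with open('primes.pickle', 'wb') as f:
    pickle.dump(list(prime_numbers(1000000)), f)

# Later, load at the speed of disk IO:
with open('primes.pickle', 'rb') as f:
    primes = pickle.load(f)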
You could use a segmented Sieve of Eratosthenes. The memory used for each segment is reused for the next segment.
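A minimal segmented-sieve sketch (my own illustration of the idea, not tuned): sieve the base primes up to sqrt(limit) first, then mark each fixed-size window using only those primes, reusing one small buffer per window.

def segmented_sieve(limit, segment_size=32768):
    """Yield primes < limit; memory stays around sqrt(limit) + segment_size."""
    base = int(limit ** 0.5) + 1
    is_prime = [True] * base
    base_primes = []
    for p in range(2, base):
        if is_prime[p]:
            base_primes.append(p)
            for q in range(p * p, base, p):
                is_prime[q] = False
    yield from (p for p in base_primes if p < limit)
    for low in range(base, limit, segment_size):
        high = min(low + segment_size, limit)
        segment = [True] * (high - low)  # one small buffer, reused per window
        for p in base_primes:
            start = max(p * p, (low + p - 1) // p * p)  # first multiple of p in the window
            for q in range(start, high, p):
                segment[q - low] = False
        for i, flag in enumerate(segment):
            if flag:
                yield low + i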
Here is some Python3 code that uses less memory than the bitarray/bitstring solutions posted to date and about 1/8 the memory of Robert William Hanks's primesgen(), while running faster than primesgen() (marginally faster at 1,000,000, using 37KB of memory, 3x faster than primesgen() at 1,000,000,000 using under 34MB). Increasing the chunk size (variable chunk in the code) uses more memory but speeds up the program, up to a limit -- I picked a value so that its contribution to memory is under about 10% of sieve's n//30 bytes. It's not as memory-efficient as Will Ness's infinite generator (see also https://stackoverflow.com/a/19391111/5439078) because it records (and returns at the end, in compressed form) all calculated primes.
This should work correctly as long as the square root calculation comes out accurate (about 2**51 if Python uses 64-bit doubles). However you should not use this program to find primes that large!
(I didn't recalculate the offsets, just took them from code of Robert William Hanks. Thanks Robert!)
import numpy as np

def prime_30_gen(n):
    """ Input n, int-like. Yields all primes < n as Python ints"""
    wheel = np.array([2, 3, 5])
    yield from wheel[wheel < n].tolist()
    powers = 1 << np.arange(8, dtype='u1')
    odds = np.array([1, 7, 11, 13, 17, 19, 23, 29], dtype='i8')
    offsets = np.array([[0,6,10,12,16,18,22,28], [6,24,16,12,4,0,22,10],
                        [0,6,20,12,26,18,2,8],   [24,6,4,18,16,0,28,10],
                        [6,24,26,12,14,0,2,20],  [0,24,10,18,4,12,28,22],
                        [24,6,14,18,26,0,8,20],  [0,24,20,18,14,12,8,2]],
                       dtype='i8')
    offsets = {d: f for d, f in zip(odds, offsets)}
    sieve = 255 * np.ones((n + 28) // 30, dtype='u1')
    if n % 30 > 1:
        sieve[-1] >>= 8 - odds.searchsorted(n % 30)
    sieve[0] &= ~1
    for i in range((int((n - 1)**0.5) + 29) // 30):
        for odd in odds[(sieve[i] & powers).nonzero()[0]]:
            k = i * 30 + odd
            yield int(k)
            for clear_bit, off in zip(~powers, offsets[odd]):
                sieve[(k * (k + off)) // 30::k] &= clear_bit
    chunk = min(1 + (n >> 13), 1 << 15)
    for i in range(i + 1, len(sieve), chunk):
        a = (sieve[i : i + chunk, np.newaxis] & powers)
        a = np.flatnonzero(a.astype('bool')) + i * 8
        a = (a // 8 * 30 + odds[a & 7]).tolist()
        yield from a
    return sieve
Side note: If you want real speed and have the 2GB of RAM required for the list of primes to 10**9, then you should use pyprimesieve (on https://pypi.python.org/, using primesieve http://primesieve.org/).
