Optimizing a primality test based on runtime in Python - python-3.x

I'm pretty new to algorithms and runtimes, and I'm trying to optimise a bit of my code for a personal project.
import math
for num in range(0, 10000000000000000000000):
    if all((num**(num+1)+(num+1)**(num))%i!=0 for i in range(2,int(math.sqrt((num**(num+1)+(num+1)**(num))))+1)):
        print(num)
What can I do to speed this up? I know that num=80 should work but my code isn't getting past num=0, 1, 2 (it's not fast enough).
First I define my range, then I say if 'such-and-such' is prime from range 2 to sqrt(such-and-such) + 1, then return that number. Sqrt(n) + 1 is the minimum number of factors to test for the primality of n.
This is a primality test of sequence A051442

You would get a boost from computing (num**(num+1)+(num+1)**(num)) only once per iteration instead of roughly sqrt(num**(num+1)+(num+1)**(num)) times. This greatly reduces the constant factor, but it won't change the fundamental complexity because you still need to compute a remainder for every candidate divisor. Change
if all((num**(num+1)+(num+1)**(num))%i!=0 for i in range(2,int(math.sqrt((num**(num+1)+(num+1)**(num))))+1)):
to
k = num**(num+1)+(num+1)**(num)
if all(k%i for i in range(2,int(math.sqrt(k))+1)):
The != 0 is implicit in Python: a nonzero remainder is truthy, so all() can consume k%i directly.
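As a minimal sketch of the same idea, the check can be wrapped in a helper that computes the value once and takes an exact integer square root. The names is_prime_trial and sequence_term are hypothetical, and math.isqrt assumes Python 3.8 or later. The reason to prefer it here is that math.sqrt converts its argument to a float, which loses precision for large integers and raises OverflowError for very large ones.

```python
import math


def sequence_term(num):
    """Value tested for primality: num^(num+1) + (num+1)^num."""
    return num**(num + 1) + (num + 1)**num


def is_prime_trial(k):
    """Trial division by every integer from 2 up to floor(sqrt(k))."""
    if k < 2:
        return False
    # math.isqrt works on exact integers, so there is no float rounding.
    return all(k % i for i in range(2, math.isqrt(k) + 1))


if __name__ == "__main__":
    for num in range(6):
        k = sequence_term(num)
        print(num, k, is_prime_trial(k))
```

This is still plain trial division, so it does nothing for the running time on large values; it only removes the repeated exponentiation and the float conversion.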
Update
All this is just a trivial improvement to an extremely inefficient algorithm. The biggest speedup I can think of is to reduce the check k % i to only prime i. For any composite i = a * b such that k % i == 0, it must be the case that k % a == 0 and k % b == 0 (if k is divisible by i, it must also be divisible by the factors of i).
I am assuming that you don't want to use any kind of pre-computed prime tables in your code. In that case, you can compute the table yourself. This will involve checking all the numbers up to a given sqrt(k) only once ever, instead of once per iteration of num, since we can stash the previously computed primes in say a list. That will effectively increase the lower limit of the range in your current all from 2 to the square root of the previous k.
Let's define a function to extend our list of primes by trial division against the primes found so far (not a true sieve of Eratosthenes, but in the same spirit):
from math import sqrt
def extend(primes, from_, to):
    """
    primes: a sequence containing prime numbers from 2 to `from - 1`, in order
    from_: the number to start checking with
    to: the number to end with (inclusive)
    """
    if not primes:
        primes.extend([2, 3])
        return
    for k in range(max(from_, 5), to + 1):
        s = int(sqrt(k)) # No need to compute this more than once per k
        for p in primes:
            if p > s:
                # Reached sqrt(k) -> short circuit success
                primes.append(k)
                break
            elif not k % p:
                # Found factor -> short circuit failure
                break
Now we can use this function to extend our list of primes at every iteration of the original loop. This allows us to check the divisibility of k only against the slowly growing list of primes, not against all numbers:
primes = []
prev = 0
for num in range(10000000000000000000000):
    k = num**(num + 1) + (num + 1)**num
    lim = int(sqrt(k)) + 1
    extend(primes, prev, lim)
    #print('Num={}, K={}, checking {}-{}, p={}'.format(num, k, prev, lim, primes), end='... ')
    if k <= 3 and k in primes or all(k % i for i in primes):
        print('{}: {} Prime!'.format(num, k))
    else:
        print('{}: {} Nope'.format(num, k))
    prev = lim + 1
I am not 100% sure that my extend function is optimal, but I am able to get to num == 13, k == 4731091158953433 in <10 minutes on my ridiculously old and slow laptop, so I guess it's not too bad. That means that the algorithm builds a complete table of primes up to ~7e7 in that time.
Update #2
A sort-of-but-not-really optimization you could do would be to check all(k % i for i in primes) before calling extend. This would save you a lot of cycles for numbers that have small prime factors, but would probably catch up to you later on, when you would end up having to compute all the primes up to some enormous number. Here is a sample of how you could do that:
primes = []
prev = 0
for num in range(10000000000000000000000):
    k = num**(num + 1) + (num + 1)**num
    lim = int(sqrt(k)) + 1
    if not all(k % i for i in primes):
        print('{}: {} Nope'.format(num, k))
        continue
    start = len(primes)
    extend(primes, prev, lim)
    if all(k % i for i in primes[start:]):
        print('{}: {} Prime!'.format(num, k))
    else:
        print('{}: {} Nope'.format(num, k))
    prev = lim + 1
While this version does not do much for the long run, it does explain why you were able to get to 15 so quickly in your original run. The prime table does not get extended after num == 3, until num == 16, which is when the terrible delay occurs in this version as well. The net runtime to 16 should be identical in both versions.
Update #3
As @paxdiablo suggests, the only numbers we need to consider in extend are multiples of 6 +/- 1. We can combine that with the fact that only a small number of primes generally need to be tested, and convert the functionality of extend into a generator that will only compute as many primes as absolutely necessary. Using Python's lazy generation should help. Here is a completely rewritten version:
from itertools import count
from math import ceil, sqrt

prime_table = [2, 3]

def prime_candidates(start=0):
    """
    Infinite generator of prime number candidates starting with the
    specified number.
    Candidates are 2, 3 and all numbers that are of the form 6n-1 and 6n+1
    """
    if start <= 3:
        if start <= 2:
            yield 2
        yield 3
        start = 5
        delta = 2
    else:
        m = start % 6
        if m < 2:
            start += 1 - m
            delta = 4
        else:
            start += 5 - m
            delta = 2
    while True:
        yield start
        start += delta
        delta = 6 - delta

def isprime(n):
    """
    Checks if `n` is prime.
    All primes up to sqrt(n) are expected to already be present in
    the generated `prime_table`.
    """
    s = int(ceil(sqrt(n)))
    for p in prime_table:
        if p > s:
            break
        if not n % p:
            return False
    return True

def generate_primes(max):
    """
    Generates primes up to the specified maximum.
    First the existing table is yielded. Then, the new primes are
    found in the sequence generated by `prime_candidates`. All verified
    primes are added to the existing cache.
    """
    for p in prime_table:
        if p > max:
            return
        yield p
    for k in prime_candidates(prime_table[-1] + 1):
        if isprime(k):
            prime_table.append(k)
            if k > max:
                # Putting the return here ensures that we always stop on a prime and therefore don't do any extra work
                return
            else:
                yield k

for num in count():
    k = num**(num + 1) + (num + 1)**num
    lim = int(ceil(sqrt(k)))
    b = all(k % i for i in generate_primes(lim))
    print('n={}, k={} is {}prime'.format(num, k, '' if b else 'not '))
This version gets to 15 almost instantly. It gets stuck at 16 because the smallest prime factor for k=343809097055019694337 is 573645313. Some future expectations:
17 should be a breeze: 16248996011806421522977 has factor 19
18 will take a while: 812362695653248917890473 has factor 22156214713
19 is easy: 42832853457545958193355601 is divisible by 3
20 also easy: 2375370429446951548637196401 is divisible by 58967
21: 138213776357206521921578463913 is divisible by 13
22: 8419259736788826438132968480177 is divisible by 103
etc... (link to sequence)
So in terms of instant gratification, this method will get you much further if you can make it past 18 (which will take >100 times longer than getting past 16, which in my case took ~1.25hrs).
That being said, your greatest speedup at this point would be re-writing this in C or some similar low-level language that does not have as much overhead for loops.
Update #4
Just for giggles, here is an implementation of the latest Python version in C. I chose to go with GMP for arbitrary precision integers, because it is easy to use and install on my Red Hat system, and the docs are very clear:
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
typedef struct {
size_t alloc;
size_t size;
mpz_t *numbers;
} PrimeTable;
void init_table(PrimeTable *buf)
{
buf->alloc = 0x100000L;
buf->size = 2;
buf->numbers = malloc(buf->alloc * sizeof(mpz_t));
if(buf->numbers == NULL) {
fprintf(stderr, "No memory for prime table\n");
exit(1);
}
mpz_init_set_si(buf->numbers[0], 2);
mpz_init_set_si(buf->numbers[1], 3);
return;
}
void append_table(PrimeTable *buf, mpz_t number)
{
if(buf->size == buf->alloc) {
size_t new = 2 * buf->alloc;
mpz_t *tmp = realloc(buf->numbers, new * sizeof(mpz_t));
if(tmp == NULL) {
fprintf(stderr, "Ran out of memory for prime table\n");
exit(1);
}
buf->alloc = new;
buf->numbers = tmp;
}
mpz_init_set(buf->numbers[buf->size], number); // the new slot is uninitialized, so initialize and set in one call
buf->size++;
return;
}
size_t print_table(PrimeTable *buf, FILE *file)
{
size_t i, n;
n = fprintf(file, "Array contents = [");
for(i = 0; i < buf->size; i++) {
n += mpz_out_str(file, 10, buf->numbers[i]);
if(i < buf->size - 1)
n += fprintf(file, ", ");
}
n += fprintf(file, "]\n");
return n;
}
void free_table(PrimeTable *buf)
{
for(buf->size--; ((signed)(buf->size)) >= 0; buf->size--)
mpz_clear(buf->numbers[buf->size]);
free(buf->numbers);
return;
}
int isprime(mpz_t num, PrimeTable *table)
{
mpz_t max, rem, next;
size_t i, d, r;
mpz_inits(max, rem, NULL);
mpz_sqrtrem(max, rem, num);
// Check if perfect square: definitely not prime
if(!mpz_cmp_si(rem, 0)) {
mpz_clears(rem, max, NULL);
return 0;
}
/* Normal table lookup */
for(i = 0; i < table->size; i++) {
// Got to sqrt(n) -> prime
if(mpz_cmp(max, table->numbers[i]) < 0) {
mpz_clears(rem, max, NULL);
return 1;
}
// Found a factor -> not prime
if(mpz_divisible_p(num, table->numbers[i])) {
mpz_clears(rem, max, NULL);
return 0;
}
}
/* Extend table and do lookup */
// Start with last found prime + 2
mpz_init_set(next, table->numbers[i - 1]);
mpz_add_ui(next, next, 2);
// Find nearest number of form 6n-1 or 6n+1
r = mpz_fdiv_ui(next, 6);
if(r < 2) {
mpz_add_ui(next, next, 1 - r);
d = 4;
} else {
mpz_add_ui(next, next, 5 - r);
d = 2;
}
// Step along numbers of form 6n-1/6n+1. Check each candidate for
// primality. Don't stop until next prime after sqrt(n) to avoid
// duplication.
for(;;) {
if(isprime(next, table)) {
append_table(table, next);
if(mpz_divisible_p(num, next)) {
mpz_clears(next, rem, max, NULL);
return 0;
}
if(mpz_cmp(max, next) <= 0) {
mpz_clears(next, rem, max, NULL);
return 1;
}
}
mpz_add_ui(next, next, d);
d = 6 - d;
}
// Return can only happen from within loop.
}
int main(int argc, char *argv[])
{
PrimeTable table;
mpz_t k, a, b;
size_t n, next;
int p;
init_table(&table);
mpz_inits(k, a, b, NULL);
for(n = 0; ; n = next) {
next = n + 1;
mpz_set_ui(a, n);
mpz_pow_ui(a, a, next);
mpz_set_ui(b, next);
mpz_pow_ui(b, b, n);
mpz_add(k, a, b);
p = isprime(k, &table);
printf("n=%zu k=", n);
mpz_out_str(stdout, 10, k);
printf(" p=%d\n", p);
//print_table(&table, stdout);
}
mpz_clears(b, a, k, NULL);
free_table(&table);
return 0;
}
While this version has the exact same algorithmic complexity as the Python one, I expected it to run a few orders of magnitude faster because of the relatively minimal overhead incurred in C. In practice, it took about 15 minutes to get stuck at n == 18, which is only ~5 times faster than the Python version so far.
Update #5
This is going to be the last one, I promise.
GMP has a function called mpz_nextprime, which offers a potentially much faster implementation of this algorithm, especially with caching. According to the docs:
This function uses a probabilistic algorithm to identify primes. For practical purposes it’s adequate, the chance of a composite passing will be extremely small.
This means that it is probably much faster than the current prime generator I implemented, with a slight cost offset of some false primes being added to the cache. This cost should be minimal: even adding a few thousand extra modulo operations should be fine if the prime generator is faster than it is now.
The only part that needs to be replaced/modified is the portion of isprime below the comment /* Extend table and do lookup */. Basically that whole section just becomes a series of calls to mpz_nextprime instead of recursion.
At that point, you may as well adapt isprime to use mpz_probab_prime_p when possible. You only need to check for sure if the result of mpz_probab_prime_p is uncertain:
int isprime(mpz_t num, PrimeTable *table)
{
mpz_t max, rem, next;
size_t i, r;
int status;
status = mpz_probab_prime_p(num, 50);
// Status = 2 -> definite yes, Status = 0 -> definite no
if(status != 1)
return status != 0;
mpz_inits(max, rem, NULL);
mpz_sqrtrem(max, rem, num);
// Check if perfect square: definitely not prime
if(!mpz_cmp_si(rem, 0)) {
mpz_clears(rem, max, NULL);
return 0;
}
mpz_clear(rem);
/* Normal table lookup */
for(i = 0; i < table->size; i++) {
// Got to sqrt(n) -> prime
if(mpz_cmp(max, table->numbers[i]) < 0) {
mpz_clear(max);
return 1;
}
// Found a factor -> not prime
if(mpz_divisible_p(num, table->numbers[i])) {
mpz_clear(max);
return 0;
}
}
/* Extend table and do lookup */
// Start from the last found prime: mpz_nextprime returns the next prime strictly greater than its argument, so adding 2 first could skip a prime
mpz_init_set(next, table->numbers[i - 1]);
// Step along probable primes
for(;;) {
mpz_nextprime(next, next);
append_table(table, next);
if(mpz_divisible_p(num, next)) {
r = 0;
break;
}
if(mpz_cmp(max, next) <= 0) {
r = 1;
break;
}
}
mpz_clears(next, max, NULL);
return r;
}
Sure enough, this version makes it to n == 79 in a couple of seconds at most. It appears to get stuck on n == 80, probably because mpz_probab_prime_p can't determine if k is a prime for sure. I doubt that computing all the primes up to ~10^80 is going to take a trivial amount of time.
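To make the probabilistic idea concrete, here is a minimal Miller-Rabin sketch in Python. This is not GMP's implementation (mpz_probab_prime_p has its own internals); it is a swapped-in textbook test, and the function name is_probable_prime and the fixed list of bases are assumptions made for illustration. A True result means "probably prime", while False is always a definite "composite".

```python
def is_probable_prime(n, bases=(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)):
    """Miller-Rabin test with a fixed set of small prime bases."""
    if n < 2:
        return False
    # Cheap trial division by the bases themselves.
    for p in bases:
        if n % p == 0:
            return n == p
    # Write n - 1 as d * 2^s with d odd.
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)  # modular exponentiation, cheap even for huge n
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False  # a is a witness: n is definitely composite
    return True


if __name__ == "__main__":
    for num in range(1, 30):
        k = num**(num + 1) + (num + 1)**num
        if is_probable_prime(k):
            print(num, k)
```

Each round costs only a modular exponentiation, which is why a test of this kind can handle values far beyond the reach of trial division. The trade-off is the one noted above: a positive answer is not a proof.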

Related

What is the worst case for binary search

Where should an element be located in the array so that the run time of the Binary search algorithm is O(log n)?
The worst case occurs when the search has to shrink the range down to a single element before it finds the target (or fails to find it). With the midpoint rounded down, as in the code below, the last element always needs the maximum number of comparisons; the first element does only when the array length is one less than a power of two.
Example:
1 2 3 4 5 6 7 8 9
Here searching for 9 gives the worst case, with the result coming in the 4th pass (compare with 5, 7, 8 and finally 9). Searching for 1 takes only 3 passes (5, 2 and finally 1).
1 2 3 4 5 6 7 8
In this case, searching for 8 will also give the worst case, with the result coming in 4 passes.
Note that in the second case searching for 1 (the first element) can be done in just 3 passes (compare 1 & 4, compare 1 & 2 and finally 1).
So the last element gives the worst case whether the number of elements is even or odd.
This is assuming all arrays are 0 indexed. The asymmetry comes from taking mid as the floor of (start + end) / 2.
// Java implementation of iterative Binary Search
class BinarySearch
{
// Returns index of x if it is present in arr[],
// else return -1
int binarySearch(int arr[], int x)
{
int l = 0, r = arr.length - 1;
while (l <= r)
{
int m = l + (r-l)/2;
// Check if x is present at mid
if (arr[m] == x)
return m;
// If x greater, ignore left half
if (arr[m] < x)
l = m + 1;
// If x is smaller, ignore right half
else
r = m - 1;
}
// if we reach here, then element was
// not present
return -1;
}
// Driver method to test above
public static void main(String args[])
{
BinarySearch ob = new BinarySearch();
int arr[] = {2, 3, 4, 10, 40};
int n = arr.length;
int x = 10;
int result = ob.binarySearch(arr, x);
if (result == -1)
System.out.println("Element not present");
else
System.out.println("Element found at " +
"index " + result);
}
}
Time Complexity:
The time complexity of Binary Search can be written as
T(n) = T(n/2) + c
The above recurrence can be solved either using the Recurrence Tree method or the Master method. It falls in case II of the Master Method and the solution of the recurrence is Theta(Logn).
Auxiliary Space: O(1) in case of iterative implementation. In case of recursive implementation, O(Logn) recursion call stack space.
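The worst-case discussion can be checked empirically with a small Python sketch that counts how many midpoints are examined before a given position is found. The function name probes_to_find is hypothetical, and the midpoint rule (rounding down) is assumed to match the Java code above.

```python
def probes_to_find(n, index):
    """Count the midpoints examined when searching for the item at `index`
    in a sorted array of n items, with the midpoint rounded down."""
    lo, hi, probes = 0, n - 1, 0
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        probes += 1
        if mid == index:
            return probes
        if mid < index:
            lo = mid + 1  # target is in the right half
        else:
            hi = mid - 1  # target is in the left half
    return probes


if __name__ == "__main__":
    for n in (8, 9):
        print(n, [probes_to_find(n, i) for i in range(n)])
    # For n = 9 this prints [3, 2, 3, 4, 1, 3, 2, 3, 4]
```

The maximum in each list is floor(log2 n) + 1, which is where the Theta(Logn) bound comes from.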

Maximum element in array which is equal to product of two elements in array

We need to find the maximum element in an array which is also equal to product of two elements in the same array. For example [2,3,6,8] , here 6=2*3 so answer is 6.
My approach was to sort the array and then use a two-pointer method to check whether the product exists for each element. This is an O(n log n) + O(n^2) = O(n^2) approach. Is there a faster way to do this?
There is a slightly better solution, O(n * sqrt(M)), if you are allowed to use O(M) memory, where M is the largest number in A.
Use an array of size M to mark every number while you traverse them from smaller to bigger number.
For each number try all its factors and see if those were already present in the array map.
Here is a pseudo code for that:
#define M 1000000
int array_map[M+2];
int ans = -1;
sort(A,A+n);
for(i=0;i<n;i++) {
for(j=1;j<=sqrt(A[i]);j++) {
int num1 = j;
if(A[i]%num1==0) {
int num2 = A[i]/num1;
if(array_map[num1] && array_map[num2]) {
if(num1==num2) {
if(array_map[num1]>=2) ans = A[i];
} else {
ans = A[i];
}
}
}
}
array_map[A[i]]++;
}
There is an even better approach: if you know how to find all possible factors in O(log M), this becomes O(n * log M). You have to use a sieve and backtracking for that.
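One way to read the sieve suggestion is a smallest-prime-factor table, sketched below in Python. The names build_spf and divisors are hypothetical, and the divisor list is built iteratively rather than by backtracking. Note the assumption about cost: the prime factorization takes O(log M) steps with the table, but the number of divisors itself can be larger than log M.

```python
def build_spf(limit):
    """spf[x] is the smallest prime factor of x, for 2 <= x <= limit."""
    spf = list(range(limit + 1))
    i = 2
    while i * i <= limit:
        if spf[i] == i:  # i is prime
            for j in range(i * i, limit + 1, i):
                if spf[j] == j:
                    spf[j] = i
        i += 1
    return spf


def divisors(x, spf):
    """All divisors of x, built from its prime factorization."""
    result = [1]
    while x > 1:
        p, power = spf[x], 0
        while x % p == 0:
            x //= p
            power += 1
        # Multiply every divisor found so far by p^0 .. p^power.
        result = [d * p**e for d in result for e in range(power + 1)]
    return result


if __name__ == "__main__":
    spf = build_spf(100)
    print(sorted(divisors(12, spf)))  # [1, 2, 3, 4, 6, 12]
```

With the divisors of A[i] in hand, the inner loop of the pseudo code only needs to look up each divisor pair in array_map instead of scanning up to sqrt(A[i]).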
@JerryGoyal's solution is correct. However, I think it can be optimized even further if, instead of using the B pointer, we use binary search to find the other factor of the product whenever arr[c] is divisible by arr[a]. Here's the modification to his code:
for(c=n-1;(c>1)&& (max==-1);c--){ // loop through C
for(a=0;(a<c-1)&&(max==-1);a++){ // loop through A
if(arr[c]%arr[a]==0) // If arr[c] is divisible by arr[a]
{
if(std::binary_search(arr+a+1, arr+c, arr[c]/arr[a])) // #include <algorithm>; searches indices a+1..c-1 of the sorted array
{
max = arr[c]; // if the other factor x of arr[c] is also in the array such that arr[c] = arr[a] * x
break;
}
}
}
}
I would have commented this on his solution, unfortunately I lack the reputation to do so.
Try this.
Written in c++
#include <vector>
#include <algorithm>
using namespace std;
int MaxElement(vector< int > Input)
{
sort(Input.begin(), Input.end());
int LargestElementOfInput = 0;
int i = 0;
while (i < Input.size() - 1)
{
if (LargestElementOfInput == Input[Input.size() - (i + 1)])
{
i++;
continue;
}
else
{
if (Input[i] != 0)
{
LargestElementOfInput = Input[Input.size() - (i + 1)];
int AllowedValue = LargestElementOfInput / Input[i];
int j = 0;
while (j < Input.size())
{
if (Input[j] > AllowedValue)
break;
else if (j == i)
{
j++;
continue;
}
else
{
int Product = Input[i] * Input[j++];
if (Product == LargestElementOfInput)
return Product;
}
}
}
i++;
}
}
return -1;
}
Once you have sorted the array, then you can use it to your advantage as below.
One improvement I can see - since you want to find the max element that meets the criteria,
Start from the right most element of the array. (8)
Divide that with the first element of the array. (8/2 = 4).
Now continue with the double pointer approach, till the element at second pointer is less than the value from the step 2 above or the match is found. (i.e., till second pointer value is < 4 or match is found).
If the match is found, then you got the max element.
Else, continue the loop with next highest element from the array. (6).
Efficient solution:
2 3 8 6
Sort the array
keep 3 pointers C, B and A.
Keeping C at the last and A at 0 index and B at 1st index.
traverse the array using pointers A and B till C and check if A*B=C exists or not.
If it exists then C is your answer.
Else, Move C a position back and traverse again keeping A at 0 and B at 1st index.
Keep repeating this till you get the product or C reaches the 1st index.
Here's the complete solution:
int arr[] = new int[]{2, 3, 8, 6};
Arrays.sort(arr);
int n=arr.length;
int a,b,c,prod,max=-1;
for(c=n-1;(c>1)&& (max==-1);c--){ // loop through C
for(a=0;(a<c-1)&&(max==-1);a++){ // loop through A
for(b=a+1;b<c;b++){ // loop through B
prod=arr[a]*arr[b];
if(prod==arr[c]){
System.out.println("A: "+arr[a]+" B: "+arr[b]);
max=arr[c];
break;
}
if(prod>arr[c]){ // no need to go further
break;
}
}
}
}
System.out.println(max);
I came up with below solution where i am using one array list, and following one formula:
divisor(a or b) X quotient(b or a) = dividend(c)
Sort the array.
Put array into Collection Col.(ex. which has faster lookup, and maintains insertion order)
Have 2 pointer a,c.
keep c at last, and a at 0.
try to follow (divisor(a or b) X quotient(b or a) = dividend(c)).
Check if a is divisor of c, if yes then check for b in col.(a
If a is divisor and list has b, then c is the answer.
else increase a by 1, follow step 5, 6 till c-1.
if max not found then decrease c index, and follow the steps 4 and 5.
Check this C# solution:
-Loop through each element,
-loop and multiply each element with other elements,
-verify if the product exists in the array and is the max
private static int GetGreatest(int[] input)
{
int max = 0;
int p = 0; //product of pairs
//loop through the input array
for (int i = 0; i < input.Length; i++)
{
for (int j = i + 1; j < input.Length; j++)
{
p = input[i] * input[j];
if (p > max && Array.IndexOf(input, p) != -1)
{
max = p;
}
}
}
return max;
}
Time complexity: the two loops are O(n^2), but each Array.IndexOf call is a linear scan, so the worst case is O(n^3).
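A hash-based lookup avoids the linear scan. The Python sketch below is one possible take, not any answerer's method: the name max_product_element is hypothetical, and it assumes positive integers and that the product and its two factors must sit at three different positions in the array.

```python
from collections import Counter


def max_product_element(values):
    """Largest c in values with c == a * b for two other elements a and b.
    Assumes positive integers; returns -1 if no such element exists."""
    counts = Counter(values)
    distinct = sorted(counts)
    for c in reversed(distinct):
        counts[c] -= 1  # the target itself may not double as a factor
        for a in distinct:
            if a > c:
                break  # a positive factor cannot exceed the product
            if counts[a] == 0 or c % a:
                continue
            b = c // a
            needed = 2 if a == b else 1  # a * a needs two copies of a
            if counts[b] >= needed:
                return c
        counts[c] += 1
    return -1


if __name__ == "__main__":
    print(max_product_element([2, 3, 6, 8]))  # 6
```

Each lookup is O(1) on average, so the whole search is O(d^2) for d distinct values after the O(n log n) sort.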

speed-ups to compute the number of strings over an alphabet that have no repeated substrings of length > 1

So it seems no one knows a formula in terms of N for the number of strings over an alphabet of size N that each have no repeated substring of length >= 2. The number of De Bruijn sequences can be used to provide a lower bound. However what if we want to compute the exact number for as large of N as possible? Are there programming tricks (symmetry, etc.) that can be used to compute the number of such strings for decent sized N?
The obvious symmetry to exploit is permuting the letter identities to generate only lexicographically minimum representatives. The following Python code uses a recursive search with pruning. It doesn't get very far, so I'm not sure if you think of it as "brute force".
There might be a silver bullet waiting, but I suspect not, in which case pushing N will be like digging a tunnel with a rock hammer.
import math

def nonrep(n, s=''):
    a = (ord(min(s)) if s else ord('a'))
    c = ((ord(max(s)) + 1) if s else a)
    total = (math.factorial(n) // math.factorial((n - (c - a))))
    for b in range(a, min((c + 1), (a + n))):
        t = (s + chr(b))
        if (t[(- 2):] not in s):
            total += nonrep(n, t)
    return total

for k in range(1, 12):
    print(nonrep(k))
The next step is to memoize by used substrings. The following C++11 code can do N=6 (2131886084545954033) fairly quickly on a machine with a ton of memory.
#include <cstdio>
#include <unordered_map>
namespace {
static const int kN(6);
long long Count(std::unordered_map<unsigned long long, long long>* memo,
unsigned long long used,
int end) {
auto got(memo->find(used));
if (got != memo->end()) {
return got->second;
}
long long total(1);
for (int j(0); j != kN; ++j) {
unsigned long long bit(1ULL << (j * kN + end));
if (used < (1ULL << (j * kN)) && j != 0) {
total += (kN - j) * Count(memo, used | bit, j);
break;
} else if ((used & bit) == 0) {
total += Count(memo, used | bit, j);
}
}
if (total >= 100) {
(*memo)[used] = total;
}
return total;
}
} // namespace
int main(void) {
std::unordered_map<unsigned long long, long long> memo;
std::printf("%lld\n", 1 + kN * Count(&memo, 0, 0));
}
The next step after this probably is more symmetry breaking, e.g., by intelligently permuting the letters further to canonize used, but I think it's going to be a long way to N=7.
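For very small N, results from the programs above can be cross-checked against a plain brute-force count. The sketch below is only a sanity check under two assumptions: the function name count_nonrepeating is hypothetical, and the empty string is counted as a valid string. It relies on the observation that a string has a repeated substring of length >= 2 exactly when some adjacent pair of letters occurs twice.

```python
def count_nonrepeating(n):
    """Count strings over an n-letter alphabet in which no adjacent pair
    of letters occurs twice (the empty string is included)."""
    alphabet = range(n)

    def extend(last, seen):
        total = 1  # count the current string itself
        for ch in alphabet:
            if last is None:
                total += extend(ch, seen)
            else:
                pair = (last, ch)
                if pair not in seen:
                    seen.add(pair)
                    total += extend(ch, seen)
                    seen.remove(pair)
        return total

    return extend(None, set())


if __name__ == "__main__":
    for n in range(1, 4):
        print(n, count_nonrepeating(n))
```

For N=1 the valid strings are the empty string, "a" and "aa", so the count is 3. There is no symmetry breaking here, so the running time explodes almost immediately; it is only useful for validating smarter code.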

Minimum no. of comparisons to find median of 3 numbers

I was implementing quicksort and I wished to set the pivot to be the median of three numbers. The three numbers being the first element, the middle element, and the last element.
Could I possibly find the median in fewer comparisons?
int median(int a[], int p, int r)
{
int m = (p+r)/2;
if(a[p] < a[m])
{
if(a[p] >= a[r])
return a[p];
else if(a[m] < a[r])
return a[m];
}
else
{
if(a[p] < a[r])
return a[p];
else if(a[m] >= a[r])
return a[m];
}
return a[r];
}
If the concern is only comparisons between elements, then this can be used (note that the subtractions and products can overflow for large values).
int getMedian(int a, int b , int c) {
int x = a-b;
int y = b-c;
int z = a-c;
if(x*y > 0) return b;
if(x*z > 0) return c;
return a;
}
int32_t FindMedian(const int n1, const int n2, const int n3) {
auto _min = min(n1, min(n2, n3));
auto _max = max(n1, max(n2, n3));
return (n1 + n2 + n3) - _min - _max;
}
You can't do it in one, and you're only using two or three, so I'd say you've got the minimum number of comparisons already.
Rather than just computing the median, you might as well put them in place. Then you can get away with just 3 comparisons all the time, and you've got your pivot closer to being in place.
T median(T a[], int low, int high)
{
int middle = ( low + high ) / 2;
if( a[ middle ].compareTo( a[ low ] ) < 0 )
swap( a, low, middle );
if( a[ high ].compareTo( a[ low ] ) < 0 )
swap( a, low, high );
if( a[ high ].compareTo( a[ middle ] ) < 0 )
swap( a, middle, high );
return a[middle];
}
I know that this is an old thread, but I had to solve exactly this problem on a microcontroller that has very little RAM and does not have a h/w multiplication unit (:)). In the end I found the following works well:
static char medianIndex[] = { 1, 1, 2, 0, 0, 2, 1, 1 };
signed short getMedian(const signed short num[])
{
return num[medianIndex[(num[0] > num[1]) << 2 | (num[1] > num[2]) << 1 | (num[0] > num[2])]];
}
If you're not afraid to get your hands a little dirty with compiler intrinsics you can do it with exactly 0 branches.
The same question was discussed before on:
Fastest way of finding the middle value of a triple?
Though, I have to add that in the context of a naive implementation of quicksort, with a lot of elements, reducing the number of branches when finding the median is not so important because the branch predictor will choke either way when you start tossing elements around the pivot. More sophisticated implementations (which don't branch on the partition operation, and avoid WAW hazards) will benefit from this greatly.
remove max and min value from total sum
int med3(int a, int b, int c)
{
int tot_v = a + b + c ;
int max_v = max(a, max(b, c));
int min_v = min(a, min(b, c));
return tot_v - max_v - min_v;
}
There is actually a clever way to isolate the median element from three using a careful analysis of the 6 possible permutations (of low, median, high). In python:
def med(a, start, mid, last):
    # put the median of a[start], a[mid], a[last] in the a[start] position
    SM = a[start] < a[mid]
    SL = a[start] < a[last]
    if SM != SL:
        return
    ML = a[mid] < a[last]
    m = mid if SM == ML else last
    a[start], a[m] = a[m], a[start]
When a[start] is already the median (1/3 of the time for distinct values in random order) you need two comparisons; otherwise you need 3 (avg about 2.67). And you only swap the median element once when needed (2/3 of the time).
Full python quicksort using this at:
https://github.com/mckoss/labs/blob/master/qs.py
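The comparison count of this decision order can be checked by enumerating the six orderings of three distinct values. The helper below, median_with_count, is a hypothetical stand-in that follows the same sequence of comparisons as med but returns the median and the number of comparisons instead of swapping.

```python
from itertools import permutations


def median_with_count(x, y, z):
    """Return (median, comparisons), comparing in the same order as med()."""
    xy = x < y
    xz = x < z
    if xy != xz:
        return x, 2  # x lies between y and z
    yz = y < z
    return (y if xy == yz else z), 3


if __name__ == "__main__":
    counts = []
    for p in permutations((0, 1, 2)):
        median, used = median_with_count(*p)
        counts.append(used)
        print(p, median, used)
    print(sum(counts) / len(counts))  # 16 / 6, about 2.67
```

Two of the six orderings finish after two comparisons and the other four need the third one.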
You can write up all the permutations:
1 0 2
1 2 0
0 1 2
2 1 0
0 2 1
2 0 1
Then we want to find the position of the 1. We could do this with two comparisons, if our first comparison could split out a group of equal positions, such as the first two lines.
The issue seems to be that the first two lines are different on any comparison we have available: a<b, a<c, b<c. Hence we have to fully identify the permutation, which requires 3 comparisons in the worst case.
Using a Bitwise XOR operator, the median of three numbers can be found.
def median(a,b,c):
    m = max(a,b,c)
    n = min(a,b,c)
    ans = m^n^a^b^c
    return ans

Generate all compositions of an integer into k parts

I can't figure out how to generate all compositions (http://en.wikipedia.org/wiki/Composition_%28number_theory%29) of an integer N into K parts, but only doing it one at a time. That is, I need a function that given the previous composition generated, returns the next one in the sequence. The reason is that memory is limited for my application. This would be much easier if I could use Python and its generator functionality, but I'm stuck with C++.
This is similar to Next Composition of n into k parts - does anyone have a working algorithm?
Any assistance would be greatly appreciated.
Preliminary remarks
First start from the observation that [1,1,...,1,n-k+1] is the first composition (in lexicographic order) of n over k parts, and [n-k+1,1,1,...,1] is the last one.
Now consider an example: the composition [2,4,3,1,1], here n = 11 and k=5. Which is the next one in lexicographic order? Obviously the rightmost part to be incremented is 4, because [3,1,1] is the last composition of 5 over 3 parts.
4 is at the left of 3, the rightmost part different from 1.
So turn 4 into 5, and replace [3,1,1] by [1,1,2], the first composition of the remainder (3+1+1)-1 , giving [2,5,1,1,2]
Generation program (in C)
The following C program shows how to compute such compositions on demand in lexicographic order
#include <stdio.h>
#include <stdbool.h>
bool get_first_composition(int n, int k, int composition[k])
{
if (n < k) {
return false;
}
for (int i = 0; i < k - 1; i++) {
composition[i] = 1;
}
composition[k - 1] = n - k + 1;
return true;
}
bool get_next_composition(int n, int k, int composition[k])
{
if (composition[0] == n - k + 1) {
return false;
}
// there's an i with composition[i] > 1, and it is not 0.
// find the last one
int last = k - 1;
while (composition[last] == 1) {
last--;
}
// turn a b ... y z 1 1 ... 1
// ^ last
// into a b ... (y+1) 1 1 1 ... (z-1)
// be careful, there may be no 1's at the end
int z = composition[last];
composition[last - 1] += 1;
composition[last] = 1;
composition[k - 1] = z - 1;
return true;
}
void display_composition(int k, int composition[k])
{
char *separator = "[";
for (int i = 0; i < k; i++) {
printf("%s%d", separator, composition[i]);
separator = ",";
}
printf("]\n");
}
void display_all_compositions(int n, int k)
{
int composition[k]; // VLA. Please don't use silly values for k
for (bool exists = get_first_composition(n, k, composition);
exists;
exists = get_next_composition(n, k, composition)) {
display_composition(k, composition);
}
}
int main()
{
display_all_compositions(5, 3);
}
Results
[1,1,3]
[1,2,2]
[1,3,1]
[2,1,2]
[2,2,1]
[3,1,1]
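Since the question mentions Python generators, the same stepping rule can also be sketched as a generator for comparison. This is a direct transcription of get_first_composition and get_next_composition above; the function name compositions is an assumption, and only one composition is held in memory at a time.

```python
def compositions(n, k):
    """Yield the compositions of n into k positive parts in lexicographic order."""
    if k < 1 or n < k:
        return
    comp = [1] * (k - 1) + [n - k + 1]  # first composition
    while True:
        yield tuple(comp)
        if comp[0] == n - k + 1:  # last composition reached
            return
        # Find the rightmost part greater than 1 (it cannot be at index 0 here).
        last = k - 1
        while comp[last] == 1:
            last -= 1
        z = comp[last]
        comp[last - 1] += 1  # increment the part to its left
        comp[last] = 1       # reset this part
        comp[k - 1] = z - 1  # the remainder goes to the end


if __name__ == "__main__":
    for c in compositions(5, 3):
        print(list(c))
```

Run as a script, it prints the same six compositions as the C program for n=5, k=3.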
Weak compositions
A similar algorithm works for weak compositions (where 0 is allowed).
bool get_first_weak_composition(int n, int k, int composition[k])
{
if (n < 0 || k < 1) { // weak compositions exist even when n < k, since parts may be 0
return false;
}
for (int i = 0; i < k - 1; i++) {
composition[i] = 0;
}
composition[k - 1] = n;
return true;
}
bool get_next_weak_composition(int n, int k, int composition[k])
{
if (composition[0] == n) {
return false;
}
// there's an i with composition[i] > 0, and it is not 0.
// find the last one
int last = k - 1;
while (composition[last] == 0) {
last--;
}
// turn a b ... y z 0 0 ... 0
// ^ last
// into a b ... (y+1) 0 0 0 ... (z-1)
// be careful, there may be no 0's at the end
int z = composition[last];
composition[last - 1] += 1;
composition[last] = 0;
composition[k - 1] = z - 1;
return true;
}
Results for n=5 k=3
[0,0,5]
[0,1,4]
[0,2,3]
[0,3,2]
[0,4,1]
[0,5,0]
[1,0,4]
[1,1,3]
[1,2,2]
[1,3,1]
[1,4,0]
[2,0,3]
[2,1,2]
[2,2,1]
[2,3,0]
[3,0,2]
[3,1,1]
[3,2,0]
[4,0,1]
[4,1,0]
[5,0,0]
Similar algorithms can be written for compositions of n into k parts greater than some fixed value.
You could try something like this:
start with the array [1,1,...,1,N-k+1] of (K-1) ones and 1 entry with the remainder. The next composition can be created by incrementing the (K-1)th element and decreasing the last element. Do this trick as long as the last element is bigger than the second to last.
When the last element becomes smaller, increment the (K-2)th element, set the (K-1)th element to the same value and set the last element to the remainder again. Repeat the process and apply the same principle for the other elements when necessary.
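You end up with a constantly sorted array, so each multiset of parts appears only once (strictly speaking, this enumerates partitions into K parts rather than ordered compositions).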
