Idiomatic implementation of the tribonacci sequence in Rust

I’m new to Rust, but as a fan of Haskell, I greatly appreciate the way match works in Rust. Now I’m faced with the rare case where I do need fall-through – in the sense that I would like all matching cases of several overlapping ones to be executed. This works:
fn options(stairs: i32) -> i32 {
if stairs == 0 {
return 1;
}
let mut count: i32 = 0;
if stairs >= 1 {
count += options(stairs - 1);
}
if stairs >= 2 {
count += options(stairs - 2);
}
if stairs >= 3 {
count += options(stairs - 3);
}
count
}
My question is whether this is idiomatic in Rust or whether there is a better way.
The context is a question from Cracking the Coding Interview: “A child is running up a staircase with n steps and can hop either 1 step, 2 steps, or 3 steps at a time. Implement a method to count how many possible ways the child can run up the stairs.”

Based on the definition of the tribonacci sequence I found you could write it in a more concise manner like this:
fn options(stairs: i32) -> i32 {
    match stairs {
        // 0 or 1 stairs: one way; 2 stairs: two ways (1+1 or 2)
        0 | 1 => 1,
        2 => 2,
        _ => options(stairs - 1) + options(stairs - 2) + options(stairs - 3),
    }
}
I would also recommend changing the function definition to accept only non-negative integers, e.g. u32.

To answer the generic question, I would argue that match and fallthrough are somewhat antithetical.
match is used to perform different actions based on different patterns. Most of the time, the values extracted via pattern matching are so different that a fallthrough does not make sense.
A fallthrough, instead, points to a sequence of actions. There are many ways to express sequences: recursion, iteration, ...
In your case, for example, one could use a loop:
for i in 1..4 {
if stairs >= i {
count += options(stairs - i);
}
}
Of course, I find @ljedrz's solution even more elegant in this particular instance.
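Put together, the loop fragment gives a complete function. The sketch below goes one step further and replaces the loop with an iterator chain over the hop sizes; it is still the same exponential recursion as the question's version, only written differently, and negative input simply yields 0:

```rust
/// Counts the ways to climb `stairs` steps hopping 1, 2 or 3 at a time.
fn options(stairs: i32) -> i32 {
    // Exactly one way to climb zero stairs: do nothing.
    if stairs == 0 {
        return 1;
    }
    // Every hop size that still fits contributes the ways to climb the rest.
    (1..=3)
        .filter(|&hop| stairs >= hop)
        .map(|hop| options(stairs - hop))
        .sum()
}
```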

I would advise avoiding recursion in Rust. It is better to use iterators:
struct Trib(usize, usize, usize);
impl Default for Trib {
fn default() -> Trib {
Trib(1, 0, 0)
}
}
impl Iterator for Trib {
type Item = usize;
fn next(&mut self) -> Option<usize> {
let &mut Trib(a, b, c) = self;
let d = a + b + c;
*self = Trib(b, c, d);
Some(d)
}
}
fn options(stairs: usize) -> usize {
Trib::default().take(stairs + 1).last().unwrap()
}
fn main() {
for (i, v) in Trib::default().enumerate().take(10) {
println!("i={}, t={}", i, v);
}
println!("{}", options(0));
println!("{}", options(1));
println!("{}", options(3));
println!("{}", options(7));
}
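If you would rather not define a struct, the same idea can be sketched with std::iter::successors, carrying the last three counts as a tuple. This is my own variant, not part of the answer above: it returns Option<u64> and uses checked_add so the sequence ends instead of overflowing. Because successors computes one state ahead, it may give up a couple of indices before the true overflow point:

```rust
use std::iter::successors;

/// Ways to climb `stairs` steps, or None if the count does not fit in a u64.
fn options(stairs: usize) -> Option<u64> {
    // State: (ways(i), ways(i + 1), ways(i + 2)), starting at i = 0.
    successors(
        Some((1u64, 1u64, 2u64)),
        |&(a, b, c): &(u64, u64, u64)| -> Option<(u64, u64, u64)> {
            // `?` ends the sequence when the next count would overflow.
            Some((b, c, a.checked_add(b)?.checked_add(c)?))
        },
    )
    .map(|(a, _, _)| a)
    .nth(stairs)
}
```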

Your code looks pretty idiomatic to me, although @ljedrz has suggested an even more elegant rewriting of the same strategy.
Since this is an interview problem, it's worth mentioning that neither solution is going to be seen as an amazing answer because both solutions take exponential time in the number of stairs.
Here is what I might write if I were trying to crack a coding interview:
fn options(stairs: usize) -> u128 {
let mut o = vec![1, 1, 2, 4];
for _ in 3..stairs {
o.push(o[o.len() - 1] + o[o.len() - 2] + o[o.len() - 3]);
}
o[stairs]
}
Instead of recomputing options(n) each time, we cache each value in a vector, so this should run in linear time instead of exponential time. I also switched to u128 to be able to return solutions for larger inputs.
Keep in mind that this is not the most efficient solution because it uses linear space. You can get away with constant space by keeping track of only the last three elements of the vector. I chose this as a compromise between conciseness, readability, and efficiency.
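The constant-space version mentioned above might look like this. It is a sketch that keeps a sliding window of three counts (destructuring assignment needs Rust 1.59 or later); like the vector version, it eventually overflows u128 for very large inputs:

```rust
/// Ways to climb `stairs` steps using O(1) extra space.
fn options(stairs: usize) -> u128 {
    // Window: (ways(i), ways(i + 1), ways(i + 2)), starting at i = 0.
    let (mut a, mut b, mut c): (u128, u128, u128) = (1, 1, 2);
    for _ in 0..stairs {
        // Slide one step: ways(i + 3) = ways(i) + ways(i + 1) + ways(i + 2).
        (a, b, c) = (b, c, a + b + c);
    }
    a
}
```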

Related

BigUint binary complement

Example code:
use num_bigint::BigUint;
use num_traits::identities::One;
fn main() {
// Example: 10001 (17) => 1110 (14)
let n = BigUint::from(17u32);
println!("{}", n);
// BigUint doesn't support `!`
//let n = !n;
let mask = (BigUint::one() << n.bits()) - 1u32;
let n = n ^ mask;
println!("{}", n);
}
The above code is doing a binary complement of a BigUint using a bit mask. Questions:
Is there a better way to do the binary complement than with a mask? It seems BigUint doesn't include the ! operator (but the mask may be necessary anyway depending on how ! was defined).
If not is there a better way to generate the mask? (Caching masks helps, but can use lots of memory.)
More context with the problem I'm actually looking at: binary complement sequences
If you alternate multiplying by 3 and bit flipping a number some interesting sequences arise. Example starting with 3:
0. 3 (11b) => 3*3 = 9 (1001b) => bit complement is 6 (0110b)
1. 6 (110b)
2. 13 (1101b)
3. 24 (11000b)
4. 55 (110111b)
5. 90 (1011010b)
6. 241 (11110001b)
7. 300 (100101100b)
8. 123 (1111011b)
9. 142 (10001110b)
10. 85 (1010101b)
11. 0 (0b)
One question is whether it reaches zero for all starting numbers or not. Some meander around for quite a while before reaching zero (425720 takes 87,037,147,316 iterations to reach 0). Being able to compute this efficiently can help in answering these questions. Mostly I'm learning a bit more Rust with this, though.
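For starting values whose iterates stay below 2^64, the sequence can be explored with a plain u64 and no big-integer crate at all. This is a minimal sketch with hypothetical helper names (step, sequence); step returns None once multiplying by 3 would overflow, and sequence loops until 0, so only call it on starting values known to terminate:

```rust
/// One step: multiply by 3, then flip every bit up to and including the
/// highest set bit. Returns None if `n * 3` overflows u64.
fn step(n: u64) -> Option<u64> {
    let m = n.checked_mul(3)?;
    // Number of significant bits in m (0 when m == 0).
    let bits = u64::BITS - m.leading_zeros();
    // All-ones mask over those bits; a shift by 64 would overflow, so special-case it.
    let mask = if bits == u64::BITS { u64::MAX } else { (1u64 << bits) - 1 };
    Some(m ^ mask)
}

/// Collects the iterates from `start` down to 0.
/// Panics if a value leaves the u64 range; never returns if 0 is not reached.
fn sequence(start: u64) -> Vec<u64> {
    let mut seq = vec![start];
    let mut n = start;
    while n != 0 {
        n = step(n).expect("value no longer fits in a u64");
        seq.push(n);
    }
    seq
}
```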
If you are looking for performance, num-bigint probably isn't the best choice. Everything that is really high-performance, though, seems to be GPL licensed.
Either way, here is a solution using the rug library, which directly supports the ! (not) operation and seems to be really fast:
use rug::{ops::NotAssign, Integer};
fn main() {
// Example: 10001 (17) => 1110 (14)
let mut n = Integer::from(17u32);
println!("{}", n);
n.not_assign();
n.keep_bits_mut(n.significant_bits() - 1);
println!("{}", n);
}
17
14
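Note that not_assign also inverts the sign bit: rug gives bitwise operations two's-complement semantics, so not_assign turns n into -(n + 1), a negative number with an infinite run of leading one bits. keep_bits_mut then keeps only the low bits, which removes that sign extension and leaves the complement of the original bits.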
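The same two's-complement behaviour can be reproduced with a primitive i64, which may make it easier to see why not_assign followed by keep_bits_mut yields the complement. This is an illustration only, valid for 0 <= n < i64::MAX, and complement_via_twos is a name made up for it:

```rust
/// Complement of the significant bits of `n`, mimicking rug's
/// `not_assign` + `keep_bits_mut(significant_bits() - 1)` with an i64.
fn complement_via_twos(n: i64) -> i64 {
    // On a signed integer, `!n` is -(n + 1): every bit flipped, sign bits included.
    let flipped = !n;
    // Bits needed for |flipped| = n + 1 (rug's significant_bits is also based on the absolute value).
    let bits = 64 - (n + 1).leading_zeros();
    // Keep only the low `bits - 1` bits, dropping the run of leading ones.
    flipped & ((1i64 << (bits - 1)) - 1)
}
```

For example, 17 is 10001b and becomes 1110b, i.e. 14, matching the rug output above.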
For example, here is a version of your algorithm:
use rug::{ops::NotAssign, Integer};
fn step(n: &mut Integer) {
*n *= 3;
n.not_assign();
n.keep_bits_mut(n.significant_bits() - 1);
}
fn main() {
let mut n = Integer::from(3);
println!("{}", n);
while n != 0 {
step(&mut n);
println!("{}", n);
}
}
3
6
13
24
55
90
241
300
123
142
85
0
The best solution is probably to just do it yourself. You perform an allocation each time you create a BigUint, which really slows down your program. Since we are not doing complex math, we can simplify most of this to a couple of bitwise operations.
After a little bit of tinkering, here is how I implemented it. For convenience, I used the unstable nightly feature bigint_helper_methods to allow for the carrying_add function. This helped simplify the addition process.
#![feature(bigint_helper_methods)] // nightly only: enables u64::carrying_add
#[derive(Debug)]
pub struct BigUintHelper {
words: Vec<u64>,
}
impl BigUintHelper {
pub fn mul3_invert(&mut self) {
let len = self.words.len();
// Multiply everything by 3 by adding it to itself with a bit shift
let mut carry = false;
let mut prev_bit = 0;
for word in &mut self.words[..len - 1] {
let previous = *word;
// Perform the addition operation
let (next, next_carry) = previous.carrying_add((previous << 1) | prev_bit, carry);
// Reset carried values for next round
prev_bit = previous >> (u64::BITS - 1);
carry = next_carry;
// Invert the result as we go to avoid needing another pass
*word = !next;
}
// Perform the last word separately since we may need to do the invert differently
let previous = self.words[len - 1];
let (next, next_carry) = previous.carrying_add((previous << 1) | prev_bit, carry);
// Extra word from the combination of the carry bits
match next_carry as u64 + (previous >> (u64::BITS - 1)) {
0 => {
// The carry was 0 so we do the normal process
self.words[len - 1] = invert_bits(next);
self.cleanup_end();
}
1 => {
self.words[len - 1] = !next;
// invert_bits(1) = 0
self.cleanup_end();
}
2 => {
self.words[len - 1] = !next;
// invert_bits(2) = 1
self.words.push(1);
}
_ => unreachable!(),
}
}
/// Remove any high order words without any bits
#[inline(always)]
fn cleanup_end(&mut self) {
while let Some(x) = self.words.pop() {
if x != 0 {
self.words.push(x);
break;
}
}
}
/// Count how many rounds it takes to convert this value to 0.
pub fn count_rounds(&mut self) -> u64 {
let mut rounds = 0;
while !self.words.is_empty() {
self.mul3_invert();
rounds += 1;
}
rounds
}
}
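impl From<u64> for BigUintHelper {
fn from(x: u64) -> Self {
// Zero is an empty word list, so count_rounds() reports 0 rounds for it
BigUintHelper {
words: if x == 0 { Vec::new() } else { vec![x] },
}
}
}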
#[inline(always)]
const fn invert_bits(x: u64) -> u64 {
match x.leading_zeros() {
0 => !x,
y => ((1u64 << (u64::BITS - y)) - 1) ^ x
}
}
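The code above needs nightly for carrying_add. On stable Rust, a free function built from overflowing_add can stand in for it. This is an assumed substitute, not the std method itself; each previous.carrying_add(x, carry) call above would become carrying_add(previous, x, carry):

```rust
/// Stable stand-in for the nightly `u64::carrying_add`:
/// returns (a + b + carry) mod 2^64 and whether the full sum overflowed.
fn carrying_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
    let (sum, c1) = a.overflowing_add(b);
    let (sum, c2) = sum.overflowing_add(carry as u64);
    // At most one of the two additions can overflow, so OR is enough.
    (sum, c1 || c2)
}
```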

Getting the length of an int

I am trying to get the length (the number of digits when interpreted in decimal) of an int in Rust. I found a way to do it, but I am looking for a method that comes from the primitive itself. This is what I have:
let num = 90.to_string();
println!("num: {}", num.chars().count());
// num: 2
I am looking at https://docs.rs/digits/0.3.3/digits/struct.Digits.html#method.length. Is this a good candidate? How do I use it? Or are there other crates that do it for me?
A one-liner with fewer type conversions would be ideal.
You could loop and check how often you can divide the number by 10 before it becomes a single digit.
Or in the other direction (because division is slower than multiplication), check how often you can multiply 10*10*...*10 until you reach the number:
fn length(n: u32, base: u32) -> u32 {
let mut power = base;
let mut count = 1;
while n >= power {
count += 1;
if let Some(new_power) = power.checked_mul(base) {
power = new_power;
} else {
break;
}
}
count
}
With Rust 1.67 or later (at the time of writing this needed nightly and #![feature(int_log)], where the method was called checked_log10), you can use:
n.checked_ilog10().unwrap_or(0) + 1
Here is a one-liner that doesn't require strings or floating point (with use std::iter::successors; in scope):
println!("num: {}", successors(Some(n), |&n| (n >= 10).then(|| n / 10)).count());
It counts the elements of the sequence n, n / 10, n / 100, ..., stopping at the first single-digit value; that count is the number of decimal digits, and 0 correctly gives 1.
EDIT: the first version of this answer used iterate from the (excellent and highly recommended) itertools crate, but @trentcl pointed out that successors from the stdlib does the same. For reference, here is the version using iterate:
println!("num: {}", iterate(n, |&n| n / 10).take_while(|&n| n > 0).count().max(1));
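Here's a (barely) one-liner that's faster than doing a string conversion, using std::iter (it assumes some_int < 10^19, otherwise 10u64.pow overflows):
let some_int = 9834;
// .max(1) so that 0 still counts as one digit
let decimal_places = (0..).take_while(|i| 10u64.pow(*i) <= some_int).count().max(1);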
The first method below relies on the following formula, where a and b are the logarithmic bases.
$$\log_a(x) = \frac{\log_b(x)}{\log_b(a)}$$
$$\log_a(x) = \frac{\log_2(x)}{\log_2(a)} \quad \text{(substituting 2 for } b\text{)}$$
The following function can be applied to finding the number of digits for bases that are a power of 2. This approach is very fast.
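fn num_digits_base_pow2(n: u64, b: u32) -> u32
{
    // b must be a power of two >= 2; n.max(1) keeps n == 0 (one digit) from underflowing 63 - 64
    (63 - n.max(1).leading_zeros()) / (31 - b.leading_zeros()) + 1
}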
The bits are counted for both n (the number we want to represent) and b (the base) to find their floor log2 values. Dividing the first by the second (integer division) and adding one gives $\lfloor \log_b n \rfloor + 1$, which is the number of digits in base b.
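A worked instance of the formula, with n = 255 and n = 256 in base b = 16:

$$\left\lfloor \frac{\lfloor \log_2 255 \rfloor}{\log_2 16} \right\rfloor + 1 = \left\lfloor \frac{7}{4} \right\rfloor + 1 = 2 \quad (255 = \mathtt{0xff})$$
$$\left\lfloor \frac{\lfloor \log_2 256 \rfloor}{\log_2 16} \right\rfloor + 1 = \left\lfloor \frac{8}{4} \right\rfloor + 1 = 3 \quad (256 = \mathtt{0x100})$$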
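For a general-purpose approach for arbitrary bases, a floating-point version works for most inputs, although rounding can make it off by one near exact powers of the base:
fn num_digits(n: u64, b: u32) -> u32
{
    // ceil(log_b(n + 1)); max(1) because 0 is still written with one digit
    ((n as f64 + 1.0).log(b as f64).ceil() as u32).max(1)
}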
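If num is signed, take its absolute value first:
let digits = ((num.abs() as f64 + 0.1).log10().ceil() as u32).max(1); // + 0.1 keeps exact powers of ten from landing on an integer; max(1) handles 0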
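On Rust 1.67 or later, the signed case can also be handled exactly, without floating point. A sketch (signed_digits is a hypothetical name); unsigned_abs is used because abs() overflows for i64::MIN:

```rust
/// Decimal digit count of a signed integer, ignoring the sign.
fn signed_digits(num: i64) -> u32 {
    // checked_ilog10 is None only for 0, which is written with one digit.
    num.unsigned_abs().checked_ilog10().unwrap_or(0) + 1
}
```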
A nice property of numbers worth keeping in mind is that the number of digits required to write a positive integer $x$ in base $n$ is $\lceil \log_n(x + 1) \rceil$.
Therefore, one can simply write the following function (note the cast from u32 to f32, since integers had no stable logarithm method at the time of writing):
fn length(n: u32, base: u32) -> u32 {
let n = (n+1) as f32;
n.log(base as f32).ceil() as u32
}
You can easily adapt it for negative numbers. For floating point numbers this might be a bit (i.e. a lot) more tricky.
To address Daniel's comment about the pathological cases introduced by using f32: integers now have exact logarithm methods, originally behind the nightly int_log feature and stable since Rust 1.67 under the name ilog. (In my opinion those are implementation details; focus more on understanding the algorithm than on the implementation.)
fn length(n: u32, base: u32) -> u32 {
    // checked_ilog returns None for n == 0, which is still written with one digit
    n.checked_ilog(base).unwrap_or(0) + 1
}

How do I create a non-recursive calculation of factorial using iterators and ranges?

I ran into a Rustlings exercise that keeps bugging me:
pub fn factorial(num: u64) -> u64 {
// Complete this function to return factorial of num
// Do not use:
// - return
// For extra fun don't use:
// - imperative style loops (for, while)
// - additional variables
// For the most fun don't use:
// - recursion
// Execute `rustlings hint iterators4` for hints.
}
The hint for the exercise says:
In an imperative language you might write a for loop to iterate
through multiply the values into a mutable variable. Or you might
write code more functionally with recursion and a match clause. But
you can also use ranges and iterators to solve this in rust.
I tried this approach, but I am missing something:
if num > 1 {
(2..=num).map(|n| n * ( n - 1 ) ??? ).???
} else {
1
}
Do I have to use something like .take_while instead of if?
The factorial is defined as the product of all the numbers from a starting number down to 1. We use that definition and Iterator::product:
fn factorial(num: u64) -> u64 {
(1..=num).product()
}
If you look at the implementation of Product for the integers, you'll see that it uses Iterator::fold under the hood:
impl Product for $a {
fn product<I: Iterator<Item=Self>>(iter: I) -> Self {
iter.fold($one, Mul::mul)
}
}
You could hard-code this yourself:
fn factorial(num: u64) -> u64 {
(1..=num).fold(1, |acc, v| acc * v)
}
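Both versions overflow u64 once num exceeds 20 (21! is larger than u64::MAX), which panics in debug builds and wraps silently in release builds. If that matters, try_fold with checked_mul stops at the first overflow; a sketch, where checked_factorial is my own name:

```rust
/// num!, or None if the result does not fit in a u64.
fn checked_factorial(num: u64) -> Option<u64> {
    // try_fold short-circuits as soon as checked_mul returns None.
    (1..=num).try_fold(1u64, |acc, v| acc.checked_mul(v))
}
```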
See also:
How to sum the values in an array, slice, or Vec in Rust?
How do I sum a vector using fold?
Although using .product() or .fold() is probably the best answer, you can also use .for_each().
fn factorial(num: u64) -> u64 {
let mut x = 1;
(1..=num).for_each(|i| x *= i);
x
}

Is `iter().map().sum()` as fast as `iter().fold()`?

Does the compiler generate the same code for iter().map().sum() and iter().fold()? In the end they achieve the same goal, but the first code would iterate two times, once for the map and once for the sum.
Here is an example. Which version would be faster in total?
pub fn square(s: u32) -> u64 {
match s {
s @ 1..=64 => 2u64.pow(s - 1),
_ => panic!("Square must be between 1 and 64")
}
}
pub fn total() -> u64 {
// A fold
(0..64).fold(0u64, |r, s| r + square(s + 1))
// or a map
(1..64).map(square).sum()
}
What would be good tools to look at the assembly or benchmark this?
For them to generate the same code, they'd first have to do the same thing. Your two examples do not:
fn total_fold() -> u64 {
(0..64).fold(0u64, |r, s| r + square(s + 1))
}
fn total_map() -> u64 {
(1..64).map(square).sum()
}
fn main() {
println!("{}", total_fold());
println!("{}", total_map());
}
18446744073709551615
9223372036854775807
Let's assume you meant
fn total_fold() -> u64 {
(1..64).fold(0u64, |r, s| r + square(s + 1))
}
fn total_map() -> u64 {
(1..64).map(|i| square(i + 1)).sum()
}
There are a few avenues to check:
The generated LLVM IR
The generated assembly
Benchmark
The easiest source for the IR and assembly is one of the playgrounds (official or alternate). These both have buttons to view the assembly or IR. You can also pass --emit=llvm-ir or --emit=asm to the compiler to generate these files.
Make sure to generate assembly or IR in release mode. The attribute #[inline(never)] is often useful to keep functions separate to find them easier in the output.
Benchmarking is documented in The Rust Programming Language, so there's no need to repeat all that valuable information.
Before Rust 1.14, these did not produce exactly the same assembly, but I'd wait for benchmarking or profiling data to see whether there's any meaningful impact on performance before worrying.
As of Rust 1.14, they do produce the same assembly! This is one reason I love Rust. You can write clear and idiomatic code, and smart people come along and make it equally fast.
but the first code would iterate two times, once for the map and once for the sum.
This is incorrect, and I'd love to know what source told you this so we can go correct it at that point and prevent future misunderstandings. An iterator operates on a pull basis; one element is processed at a time. The core method is next, which yields a single value, running just enough computation to produce that value.
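A quick experiment makes the pull-based behaviour visible. This hypothetical helper logs each time map produces a value and, via inspect, each time that value is handed on to sum; the log interleaves the two stages, so there is only one pass:

```rust
use std::cell::RefCell;

/// Returns the sum of squares of 1..=3 plus the order in which each stage ran.
fn trace_map_sum() -> (u64, Vec<String>) {
    // RefCell lets both closures append to the same log.
    let log = RefCell::new(Vec::new());
    let total: u64 = (1..=3u64)
        .map(|x| {
            log.borrow_mut().push(format!("map {}", x));
            x * x
        })
        // Runs right before sum receives each mapped value.
        .inspect(|y| log.borrow_mut().push(format!("sum {}", y)))
        .sum();
    (total, log.into_inner())
}
```

The log comes out as map 1, sum 1, map 2, sum 4, map 3, sum 9.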
First, let's fix those examples to actually return the same result:
pub fn total_fold_iter() -> u64 {
(1..65).fold(0u64, |r, s| r + square(s))
}
pub fn total_map_iter() -> u64 {
(1..65).map(square).sum()
}
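To confirm that the corrected pair really agrees, here they are with the question's square function (using the current 1..=64 pattern syntax). Both add up 2^0 through 2^63, i.e. 2^64 - 1, which is exactly u64::MAX:

```rust
/// Value on square `s` of the board: 2^(s - 1) for s in 1..=64.
pub fn square(s: u32) -> u64 {
    match s {
        1..=64 => 2u64.pow(s - 1),
        _ => panic!("Square must be between 1 and 64"),
    }
}

pub fn total_fold_iter() -> u64 {
    (1..65).fold(0u64, |r, s| r + square(s))
}

pub fn total_map_iter() -> u64 {
    (1..65).map(square).sum()
}
```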
Now, let's expand them, starting with fold. A fold is just a loop with an accumulator; it is roughly equivalent to:
pub fn total_fold_explicit() -> u64 {
let mut total = 0;
for i in 1..65 {
total = total + square(i);
}
total
}
Then, let's go with map and sum, and unwrap the sum first, which is roughly equivalent to:
pub fn total_map_partial_iter() -> u64 {
let mut total = 0;
for i in (1..65).map(square) {
total += i;
}
total
}
It's just a simple accumulator! And now, let's unwrap the map layer (which only applies a function), obtaining something that is roughly equivalent to:
pub fn total_map_explicit() -> u64 {
let mut total = 0;
for i in 1..65 {
let s = square(i);
total += s;
}
total
}
As you can see, both of them are extremely similar: they apply the same operations in the same order and have the same overall complexity.
Which is faster? I have no idea. And a micro-benchmark may only tell half the truth anyway: just because something is faster in a micro-benchmark does not mean it is faster in the midst of other code.
What I can say, however, is that they both have equivalent complexity and therefore should behave similarly, i.e. within a constant factor of each other.
And that I would personally go for map + sum, because it expresses the intent more clearly whereas fold is the "kitchen-sink" of Iterator methods and therefore far less informative.
