Insert a bit in an 8-bit integer - rust

I have a 5-bit u8, let's say 0b10101. I want to insert three bits, all ones, into positions 1, 2, and 4, to get: ii1i0101, i.e., 11110101. I want to accomplish this in three function calls, meaning that the function should take the index as one of the parameters and insert a single bit in that position.
I've come across this question; however, the answers on that page did not work for me. For example, the answer with the fewest upvotes panics when implemented, while the others do not give the correct result.
fn insert_at(x: u8, index: u8, bit: u8) -> u8 {
let mask = (1 << (8 - index + 1)) - 1;
(x & !mask) | (bit << (8 - index)) | ((x & mask) >> 1)
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
#[rstest(
input, index, expected,
case(0b10101, 1, 0b110101),
)]
fn test_bla(input: u8, index: u8, expected: u8) {
let result = insert_at(input, index, 1);
assert_eq!(result, expected);
}
}
thread 'tests::test_bla::case_1' panicked at 'attempt to shift left with overflow'

I've made a few assumptions (and a modification) to make the semantics of your question a bit more concrete:
The result of the operation must fit in 8 bits; if it does not, no value is returned.
index 0 (rather than index 1) refers to the position to the left of the most significant set bit.
A bit inserted in any index > 0 shifts all more significant set bits to the left by 1.
A working implementation:
fn insert_at(x: u8, index: u8, bit: u8) -> Option<u8> {
if bit != 0 && bit != 1 {
return None;
}
// most significant set bit, from right
// 0b00010101
// 87654321
let msb = 8 - x.leading_zeros() as u8;
if index >= msb {
// the new bit is past the LSB
return None;
}
// an index of 0 means insert as the new MSB (if it fits).
// insert is the absolute index of the inserted bit.
let insert = msb - index;
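if msb == 8 {
// x already uses all 8 bits, so the result cannot fit in a u8.
// (Checking only `insert == 8` catches index 0 but not index > 0,
// where `(x & !mask) << 1` would silently drop the old MSB.)
return None;
}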
let mask = (1 << insert) - 1;
Some((x & !mask) << 1 | (bit << insert) | (x & mask))
}
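Applied to the example from the question, and using the conventions above (index 0 means "left of the most significant set bit"), one possible sequence of three calls is sketched below. It is a self-contained copy of insert_at in which the "does it fit" check is written as msb == 8; that generalization is an assumption made here, since a u8 that already uses all 8 bits has no room for an inserted bit at any index.

```rust
/// Copy of the answer's insert_at, with the fit check assumed to be `msb == 8`.
fn insert_at(x: u8, index: u8, bit: u8) -> Option<u8> {
    if bit > 1 {
        return None;
    }
    // Number of significant bits in x (position of the most significant set bit).
    let msb = 8 - x.leading_zeros() as u8;
    if index >= msb || msb == 8 {
        return None;
    }
    // Absolute position (from the right, 0-based) of the inserted bit.
    let insert = msb - index;
    let mask: u8 = (1 << insert) - 1;
    Some((x & !mask) << 1 | (bit << insert) | (x & mask))
}

fn main() {
    // 1 0101 -> 1 i 0101: insert one bit just below the leading 1 (index 1).
    let x = insert_at(0b10101, 1, 1).unwrap();
    assert_eq!(x, 0b110101);
    // Two insertions at index 0 prepend new most significant bits.
    let x = insert_at(x, 0, 1).unwrap();
    let x = insert_at(x, 0, 1).unwrap();
    println!("{:08b}", x); // prints 11110101
}
```

The intermediate values are 0b110101 and 0b1110101; with a full u8 (msb == 8) every further insertion returns None.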
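The reason your implementation panics is that Rust's shift operations are checked when overflow checks are enabled (the default in debug builds): shifting by an amount greater than or equal to the bit width of the type panics with "attempt to shift left with overflow". With index = 1, your mask computes 1 << (8 - 1 + 1), i.e. 1 << 8 on a u8. Bits that are shifted off the left side with an in-range amount are simply discarded; that is not an error. The main reasoning for the check is that different platforms behave differently when the shift amount is too large.
Rust also provides shift operations with specified behavior in these cases, such as checked_shl(), overflowing_shl() and wrapping_shl().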
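A small sketch of how those three methods behave with the shift amount from the failing call (8 on a u8). shl_variants is a hypothetical helper used only for the demonstration. Note that all three look only at the shift amount: bits moved off the left edge are discarded without any error.

```rust
/// Run the three non-panicking left shifts side by side (hypothetical helper).
fn shl_variants(x: u8, n: u32) -> (Option<u8>, u8, (u8, bool)) {
    (x.checked_shl(n), x.wrapping_shl(n), x.overflowing_shl(n))
}

fn main() {
    // The shift amount 8 is what made `1 << 8` panic in the question.
    let (checked, wrapping, overflowing) = shl_variants(1, 8);
    assert_eq!(checked, None); // amount >= 8: reported as None
    assert_eq!(wrapping, 1); // amount is masked to 8 % 8 = 0, so nothing moves
    assert_eq!(overflowing, (1, true)); // same value as wrapping, plus a flag
    // An in-range amount that pushes set bits off the left edge is not an error.
    assert_eq!(0b1000_0001u8 << 1, 0b0000_0010);
    println!("{:?} {} {:?}", checked, wrapping, overflowing);
}
```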

Related

BigUint binary complement

Example code:
use num_bigint::BigUint;
use num_traits::identities::One;
fn main() {
// Example: 10001 (17) => 1110 (14)
let n = BigUint::from(17u32);
println!("{}", n);
// BigUint doesn't support `!`
//let n = !n;
let mask = (BigUint::one() << n.bits()) - 1u32;
let n = n ^ mask;
println!("{}", n);
}
The above code is doing a binary complement of a BigUint using a bit mask. Questions:
Is there a better way to do the binary complement than with a mask? It seems BigUint doesn't include the ! operator (but the mask may be necessary anyway depending on how ! was defined).
If not is there a better way to generate the mask? (Caching masks helps, but can use lots of memory.)
More context on the problem I'm actually looking at: binary complement sequences
If you alternately multiply a number by 3 and flip its bits, some interesting sequences arise. Example starting with 3:
0. 3 (11b) => 3*3 = 9 (1001b) => bit complement is 6 (0110b)
1. 6 (110b)
2. 13 (1101b)
3. 24 (11000b)
4. 55 (110111b)
5. 90 (1011010b)
6. 241 (11110001b)
7. 300 (100101100b)
8. 123 (1111011b)
9. 142 (10001110b)
10. 85 (1010101b)
11. 0 (0b)
One question is whether it reaches zero for all starting numbers or not. Some meander around for quite a while before reaching zero (425720 takes 87,037,147,316 iterations to reach 0). Being able to compute this efficiently can help in answering these questions. Mostly I'm learning a bit more rust with this though.
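For starting values whose sequence stays within 64 bits, the step can be sketched with plain u64 arithmetic, which also shows one way to build the mask without a shift that can overflow. This is a sketch under that assumption: step and sequence are hypothetical names, and the sequence simply stops if 3 * n no longer fits in a u64, which long-running starts may well hit, hence the arbitrary-precision answers.

```rust
/// One step: multiply by 3, then flip every bit up to and including the
/// most significant set bit. Returns None if 3 * n overflows u64.
fn step(n: u64) -> Option<u64> {
    let m = n.checked_mul(3)?;
    // All-ones mask covering m's significant bits; checked_shr handles m == 0,
    // where leading_zeros() is 64 and a plain `>>` would overflow.
    let mask = u64::MAX.checked_shr(m.leading_zeros()).unwrap_or(0);
    Some(!m & mask)
}

/// Collect the sequence from `start` until it reaches 0 (or 3 * n overflows).
fn sequence(start: u64) -> Vec<u64> {
    let mut out = vec![start];
    let mut n = start;
    while n != 0 {
        match step(n) {
            Some(next) => {
                n = next;
                out.push(n);
            }
            None => break,
        }
    }
    out
}

fn main() {
    // prints [3, 6, 13, 24, 55, 90, 241, 300, 123, 142, 85, 0]
    println!("{:?}", sequence(3));
}
```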
If you are looking for performance, num-bigint probably isn't the best choice. Everything that is really high-performance, though, seems to be GPL licensed.
Either way, here is a solution using the rug library, which directly supports ! (not) and seems to be really fast:
use rug::{ops::NotAssign, Integer};
fn main() {
// Example: 10001 (17) => 1110 (14)
let mut n = Integer::from(17u32);
println!("{}", n);
n.not_assign();
n.keep_bits_mut(n.significant_bits() - 1);
println!("{}", n);
}
17
14
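Note that not_assign follows two's-complement semantics, so it also inverts the sign: !17 is -18. keep_bits_mut then keeps only the low significant_bits() - 1 bits, which removes the sign and leaves the complemented value (14).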
For example, here is a version of your algorithm:
use rug::{ops::NotAssign, Integer};
fn step(n: &mut Integer) {
*n *= 3;
n.not_assign();
n.keep_bits_mut(n.significant_bits() - 1);
}
fn main() {
let mut n = Integer::from(3);
println!("{}", n);
while n != 0 {
step(&mut n);
println!("{}", n);
}
}
3
6
13
24
55
90
241
300
123
142
85
0
The best solution is probably to just do it yourself. You perform an allocation each time you create a BigUint, which really slows down your program. Since we are not doing complex math, we can simplify most of this to a couple of bitwise operations.
After a little bit of tinkering, here is how I implemented it. For convenience, I used the unstable nightly feature bigint_helper_methods to allow for the carrying_add function. This helped simplify the addition process.
#![feature(bigint_helper_methods)]
#[derive(Debug)]
pub struct BigUintHelper {
words: Vec<u64>,
}
impl BigUintHelper {
pub fn mul3_invert(&mut self) {
let len = self.words.len();
// Multiply everything by 3 by adding it to itself with a bit shift
let mut carry = false;
let mut prev_bit = 0;
for word in &mut self.words[..len - 1] {
let previous = *word;
// Perform the addition operation
let (next, next_carry) = previous.carrying_add((previous << 1) | prev_bit, carry);
// Reset carried values for next round
prev_bit = previous >> (u64::BITS - 1);
carry = next_carry;
// Invert the result as we go to avoid needing another pass
*word = !next;
}
// Perform the last word separately since we may need to do the invert differently
let previous = self.words[len - 1];
let (next, next_carry) = previous.carrying_add((previous << 1) | prev_bit, carry);
// Extra word from the combination of the carry bits
match next_carry as u64 + (previous >> (u64::BITS - 1)) {
0 => {
// The carry was 0 so we do the normal process
self.words[len - 1] = invert_bits(next);
self.cleanup_end();
}
1 => {
self.words[len - 1] = !next;
// invert_bits(1) = 0
self.cleanup_end();
}
2 => {
self.words[len - 1] = !next;
// invert_bits(2) = 1
self.words.push(1);
}
_ => unreachable!(),
}
}
/// Remove any high order words without any bits
#[inline(always)]
fn cleanup_end(&mut self) {
while let Some(x) = self.words.pop() {
if x != 0 {
self.words.push(x);
break;
}
}
}
/// Count how many rounds it takes to convert this value to 0.
pub fn count_rounds(&mut self) -> u64 {
let mut rounds = 0;
while !self.words.is_empty() {
self.mul3_invert();
rounds += 1;
}
rounds
}
}
impl From<u64> for BigUintHelper {
fn from(x: u64) -> Self {
BigUintHelper {
words: vec![x],
}
}
}
#[inline(always)]
const fn invert_bits(x: u64) -> u64 {
match x.leading_zeros() {
0 => !x,
y => ((1u64 << (u64::BITS - y)) - 1) ^ x
}
}
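If the nightly-only carrying_add is not an option, the same full-adder step can be written on stable Rust with two overflowing_add calls. A minimal sketch, where add_with_carry and mul3 are hypothetical names and mul3 only reproduces the multiply-by-3 part of mul3_invert (no bit inversion):

```rust
/// Stable stand-in for `carrying_add`: returns a + b + carry and the carry out.
fn add_with_carry(a: u64, b: u64, carry: bool) -> (u64, bool) {
    let (sum, c1) = a.overflowing_add(b);
    let (sum, c2) = sum.overflowing_add(carry as u64);
    // At most one of the two additions can overflow.
    (sum, c1 || c2)
}

/// Multiply a little-endian list of u64 words by 3 as x + (x << 1).
fn mul3(words: &[u64]) -> Vec<u64> {
    let mut out = Vec::with_capacity(words.len() + 1);
    let mut carry = false;
    let mut prev_bit = 0;
    for &w in words {
        let (sum, next_carry) = add_with_carry(w, (w << 1) | prev_bit, carry);
        // The top bit of w, shifted out of (w << 1), feeds the next word.
        prev_bit = w >> (u64::BITS - 1);
        carry = next_carry;
        out.push(sum);
    }
    // Whatever is left (0, 1 or 2) becomes an extra high-order word.
    let extra = prev_bit + carry as u64;
    if extra != 0 {
        out.push(extra);
    }
    out
}

fn main() {
    // 3 * (2^64 - 1) = 2 * 2^64 + (2^64 - 3)
    assert_eq!(mul3(&[u64::MAX]), vec![u64::MAX - 2, 2]);
    println!("{:?}", mul3(&[5]));
}
```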

How to find most and least significant bits of an unsigned number u64 in Rust?

Where the LSB is at index 0 and the MSB is at index 63. Similarly, it should extend to u32 and other types.
let my_num: u64 = 100; // 0b1100100
let msb = get_msb(my_num); // 0
let lsb = get_lsb(my_num); // 0
Correction: the MSB should be the bit at index 63 (0 here), not the 1 at index 6.
As explained in the comments, you get the LSB and MSB of u64 with n & 1 and (n >> 63) & 1 respectively.
Doing it completely generically is somewhat of a hassle in Rust, though, because generics require operations like shifting, masking, and even the construction of 1 to be fully specified upfront. However, this is where the num-traits crate comes to the rescue. Along with its cousin num, it is the de facto standard for generic Rust in the field of numerics, providing (among others) the PrimInt trait that makes get_msb() and get_lsb() straightforward:
use num_traits::PrimInt;
pub fn get_lsb<N: PrimInt>(n: N) -> N {
n & N::one()
}
pub fn get_msb<N: PrimInt>(n: N) -> N {
let shift = std::mem::size_of::<N>() * 8 - 1;
(n >> shift) & N::one()
}
fn main() {
assert_eq!(get_lsb(100u32), 0);
assert_eq!(get_lsb(101u32), 1);
assert_eq!(get_msb(100u32), 0);
assert_eq!(get_msb(u32::MAX), 1);
}
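If pulling in num-traits is not desired, a std-only sketch can get similar generic-looking calls by implementing a small local trait for each unsigned type with a macro. It relies on the BITS associated constant (stable since Rust 1.53); BitEnds and impl_bit_ends are hypothetical names.

```rust
/// Local trait giving each unsigned primitive `lsb()` and `msb()`.
trait BitEnds: Sized {
    fn lsb(self) -> Self;
    fn msb(self) -> Self;
}

macro_rules! impl_bit_ends {
    ($($t:ty),*) => {
        $(
            impl BitEnds for $t {
                fn lsb(self) -> Self {
                    self & 1
                }
                fn msb(self) -> Self {
                    // BITS is the width of the type, e.g. 64 for u64
                    (self >> (<$t>::BITS - 1)) & 1
                }
            }
        )*
    };
}

impl_bit_ends!(u8, u16, u32, u64, u128, usize);

fn main() {
    assert_eq!(100u64.lsb(), 0);
    assert_eq!(100u64.msb(), 0);
    assert_eq!(u32::MAX.msb(), 1);
    assert_eq!(0x80u8.msb(), 1);
}
```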

Getting the length of an int

I am trying to get the length (the number of digits when interpreted in decimal) of an int in Rust. I found a way to do it, but I am looking for a method that comes from the primitive itself. This is what I have:
let num = 90.to_string();
println!("num: {}", num.chars().count());
// num: 2
I am looking at https://docs.rs/digits/0.3.3/digits/struct.Digits.html#method.length. Is this a good candidate? How do I use it? Or are there other crates that do it for me?
A one-liner with fewer type conversions is the ideal solution I am looking for.
You could loop and check how often you can divide the number by 10 before it becomes a single digit.
Or in the other direction (because division is slower than multiplication), check how often you can multiply 10*10*...*10 until you reach the number:
fn length(n: u32, base: u32) -> u32 {
let mut power = base;
let mut count = 1;
while n >= power {
count += 1;
if let Some(new_power) = power.checked_mul(base) {
power = new_power;
} else {
break;
}
}
count
}
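The division-based loop from the first sentence could look like this for base 10 (a sketch; length_by_division is a hypothetical name), cross-checked against the string conversion from the question:

```rust
/// Count decimal digits by dividing by 10 until a single digit remains.
fn length_by_division(mut n: u32) -> u32 {
    let mut count = 1;
    while n >= 10 {
        n /= 10;
        count += 1;
    }
    count
}

fn main() {
    for n in [0u32, 9, 10, 99, 100, 4_294_967_295] {
        // Compare with the to_string() approach from the question.
        assert_eq!(length_by_division(n) as usize, n.to_string().len());
    }
}
```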
Since Rust 1.67, the integer logarithm from the former int_log feature is stable (renamed to ilog10 / checked_ilog10), so you can use:
n.checked_ilog10().unwrap_or(0) + 1
Here is a one-liner that doesn't require strings or floating point:
use std::iter::successors;
println!("num: {}", successors(Some(n), |&n| (n >= 10).then(|| n / 10)).count());
It simply counts the number of times the initial number needs to be divided by 10 in order to reach 0.
EDIT: the first version of this answer used iterate from the (excellent and highly recommended) itertools crate, but @trentcl pointed out that successors from the standard library does the same. For reference, here is the version using iterate:
use itertools::iterate;
println!("num: {}", iterate(n, |&n| n / 10).take_while(|&n| n > 0).count().max(1));
Here's a (barely) one-liner that's faster than doing a string conversion, using std::iter stuff:
let some_int = 9834;
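let decimal_places = (0..).take_while(|i| 10u64.pow(*i) <= some_int).count();
Note that this returns 0 for some_int = 0, and that 10u64.pow overflows (panicking in debug builds) once some_int reaches 10^19.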
The first method below relies on the following formula, where $a$ and $b$ are the logarithmic bases:
$\log_a(x) = \log_b(x) / \log_b(a)$
$\log_a(x) = \log_2(x) / \log_2(a)$ (substituting 2 for $b$)
The following function can be applied to finding the number of digits for bases that are a power of 2. This approach is very fast.
fn num_digits_base_pow2(n: u64, b: u32) -> u32
{
(63 - n.leading_zeros()) / (31 - b.leading_zeros()) + 1
}
The bits are counted for both n (the number we want to represent) and b (the base) to find their floor log2 values. Integer division of the two then gives floor(log_b(n)), and adding 1 turns that into the number of digits. (n must be nonzero: for n = 0, 63 - n.leading_zeros() underflows.)
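A variant of the same idea that also accepts n = 0 and reads log2(b) directly from trailing_zeros; it is a sketch with a u64 base (num_digits_pow2 is a hypothetical name), not a drop-in replacement for the signature above:

```rust
/// Digits of `n` in base `b`, where `b` is a power of two (2, 4, 8, 16, ...).
/// Each digit holds log2(b) bits, so the count is ceil(bit_length / log2(b)).
fn num_digits_pow2(n: u64, b: u64) -> u32 {
    assert!(b >= 2 && b.is_power_of_two(), "base must be a power of two >= 2");
    let bits_per_digit = b.trailing_zeros(); // log2(b)
    // Bit length of n, counting 0 as one bit (written as the single digit "0").
    let bits = (u64::BITS - n.leading_zeros()).max(1);
    // Ceiling division.
    (bits + bits_per_digit - 1) / bits_per_digit
}

fn main() {
    assert_eq!(num_digits_pow2(0xFF, 16), 2);
    assert_eq!(num_digits_pow2(0x100, 16), 3);
    assert_eq!(num_digits_pow2(0, 2), 1);
    assert_eq!(num_digits_pow2(u64::MAX, 8), 22);
}
```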
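For a general-purpose approach to finding the number of digits in an arbitrary base, a floating-point logarithm can compute $\lceil \log_b(n + 1) \rceil$ (without the + 1, exact powers of the base such as 100 come out one digit short). This returns 0 for n = 0, and floating-point rounding can make it off by one for large inputs.
fn num_digits(n: u64, b: u32) -> u32
{
((n as f64) + 1.0).log(b as f64).ceil() as u32
}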
If num is signed:
let digits = (num.abs() as f64 + 0.1).log10().ceil() as u32;
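An integer-only alternative for signed values, assuming Rust 1.67 or later (where checked_ilog10 is stable); digits_signed is a hypothetical name. It avoids the float rounding and the overflow that num.abs() hits on i64::MIN:

```rust
/// Decimal digit count of a signed integer, ignoring the sign.
fn digits_signed(num: i64) -> u32 {
    // unsigned_abs() cannot overflow; checked_ilog10() is None for 0 (one digit).
    num.unsigned_abs().checked_ilog10().map_or(1, |log| log + 1)
}

fn main() {
    for num in [0i64, -9, 10, -100, i64::MIN] {
        println!("{} has {} digits", num, digits_signed(num));
    }
}
```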
A nice property of numbers that is always good to keep in mind is that the number of digits required to write a number $x \ge 1$ in base $n$ is $\lceil \log_n(x + 1) \rceil$.
Therefore, one can simply write the following function (notice the cast from u32 to f32, needed to use the floating-point log function).
fn length(n: u32, base: u32) -> u32 {
let n = (n+1) as f32;
n.log(base as f32).ceil() as u32
}
You can easily adapt it for negative numbers. For floating point numbers this might be a bit (i.e. a lot) more tricky.
To take into account Daniel's comment about the pathological cases introduced by using f32, note that integers have an integer logarithm method, stable since Rust 1.67 as ilog (it was called log under the nightly int_log feature). (Notice that, imo, those are implementation details, and you should focus more on understanding the algorithm than on the implementation.):
fn length(n: u32, base: u32) -> u32 {
// ilog panics for n == 0, so count 0 as a single digit
n.checked_ilog(base).unwrap_or(0) + 1
}

How to determine the rounding direction when casting a double to a float?

I'm looking for an algorithm to determine the rounding direction when casting an arbitrary 64-bit double to a 32-bit float. My specific use-case for this check is to cast a 64-bit double to 32-bit float with rounding toward infinity.
At first, I applied the following criterion to check for bits in the truncated portion of the mantissa. If the truncated portion of the mantissa is nonzero, then the cast must have rounded down!
const F64_MASK: u64 = (1 << 29) - 1;
fn is_rounded_up_when_casted(v: f64) -> bool {
v.to_bits() & F64_MASK > 0
}
This criterion doesn't, however, identify all cases: the last three bits of the exponent are also truncated. I tried modifying the mask to check these exponent bits as well:
const F64_MASK: u64 = (1u64 << 55) - (1 << 52) + (1 << 29) - 1;
Unfortunately this check doesn't work. For example the number 1.401298464324817e−45 has an exponent in which the three truncated bits are 010 and yet is still exactly represented in float/f32.
EDIT: I don't think it can be said that a nonzero mantissa means positive rounding. I think I need a different approach. I think the exponent just increases the range of numbers, so that can be handled with some separate checks. The rounding direction may just be a function of the leading bit of the truncated portion of the mantissa?
The edge case that you found has to do with the fact that subnormal f32 values can actually represent exponents less than their typical minimum. I've written a function that I believe covers all the edge cases:
const F64_MANTISSA_SIZE: u64 = 52;
const F64_MANTISSA_MASK: u64 = (1 << F64_MANTISSA_SIZE) - 1;
const F64_EXPONENT_SIZE: u64 = 64 - F64_MANTISSA_SIZE - 1;
const F64_EXPONENT_MASK: u64 = (1 << F64_EXPONENT_SIZE) - 1; // shift away the mantissa first
const F32_MANTISSA_SIZE: u64 = 23;
const F64_TRUNCATED_MANTISSA_SIZE: u64 = F64_MANTISSA_SIZE - F32_MANTISSA_SIZE;
const F64_TRUNCATED_MANTISSA_MASK: u64 = (1 << F64_TRUNCATED_MANTISSA_SIZE) - 1;
fn is_exactly_representable_as_f32(v: f64) -> bool {
let bits = v.to_bits();
let mantissa = bits & F64_MANTISSA_MASK;
let exponent = (bits >> F64_MANTISSA_SIZE) & F64_EXPONENT_MASK;
let _sign = bits >> (F64_MANTISSA_SIZE + F64_EXPONENT_SIZE) != 0;
if exponent == 0 {
// if mantissa == 0, the float is 0 or -0, which is representable
// if mantissa != 0, it's a subnormal, which is never representable
return mantissa == 0;
}
if exponent == F64_EXPONENT_MASK {
// either infinity or nan, all of which are representable
return true;
}
// remember to subtract the bias
let computed_exponent = exponent as i64 - 1023;
// -126 and 127 are the min and max value for a standard f32 exponent
if (-126..=127).contains(&computed_exponent) {
// at this point, it's only exactly representable if the truncated mantissa is all zero
return mantissa & F64_TRUNCATED_MANTISSA_MASK == 0;
}
// exponents less than 2**(-126) may be representable by f32 subnormals
if computed_exponent < -126 {
// this is the number of leading zeroes that need to be in the f32 mantissa
let diff = -127 - computed_exponent;
// this is the number of bits in the mantissa that must be preserved (essentially mantissa with trailing zeroes trimmed off)
let mantissa_bits = F64_MANTISSA_SIZE - (mantissa.trailing_zeros() as u64).min(F64_MANTISSA_SIZE) + 1;
// the leading zeroes + essential mantissa bits must be able to fit in the smaller mantissa size
return diff as u64 + mantissa_bits <= F32_MANTISSA_SIZE;
}
// the exponent is >127 so f32s can't go that high
return false;
}
No need for mangling bits:
use std::cmp::Ordering;
#[derive(PartialEq, std::fmt::Debug)]
enum Direction { Equal, Up, Down }
fn get_rounding_direction(v: f64) -> Direction {
match v.partial_cmp(&(v as f32 as f64)) {
Some(Ordering::Greater) => Direction::Down,
Some(Ordering::Less) => Direction::Up,
_ => Direction::Equal
}
}
And some tests to check correctness.
#[cfg(test)]
#[test]
fn test_get_rounding_direction() {
// check that the f64 one step below 2 casts to exactly 2
assert_eq!(get_rounding_direction(1.9999999999999998), Direction::Up);
// check edge cases
assert_eq!(get_rounding_direction(f64::NAN), Direction::Equal);
assert_eq!(get_rounding_direction(f64::NEG_INFINITY), Direction::Equal);
assert_eq!(get_rounding_direction(f64::MIN), Direction::Down);
assert_eq!(get_rounding_direction(-f64::MIN_POSITIVE), Direction::Up);
assert_eq!(get_rounding_direction(-0.), Direction::Equal);
assert_eq!(get_rounding_direction(0.), Direction::Equal);
assert_eq!(get_rounding_direction(f64::MIN_POSITIVE), Direction::Down);
assert_eq!(get_rounding_direction(f64::MAX), Direction::Up);
assert_eq!(get_rounding_direction(f64::INFINITY), Direction::Equal);
// for all other f32
for u32_bits in 1..f32::INFINITY.to_bits() - 1 {
let f64_value = f32::from_bits(u32_bits) as f64;
let u64_bits = f64_value.to_bits();
if u32_bits % 100_000_000 == 0 {
println!("checkpoint every 600 million tests: {}", f64_value);
}
// check that the f64 equivalent to the current f32 casts to a value that is equivalent
assert_eq!(get_rounding_direction(f64_value), Direction::Equal, "at {}, {}", u32_bits, f64_value);
// check that the f64 one step below the f64 equivalent to the current f32 casts to a value that is one step greater
assert_eq!(get_rounding_direction(f64::from_bits(u64_bits - 1)), Direction::Up, "at {}, {}", u32_bits, f64_value);
// check that the f64 one step above the f64 equivalent to the current f32 casts to a value that is one step less
assert_eq!(get_rounding_direction(f64::from_bits(u64_bits + 1)), Direction::Down, "at {}, {}", u32_bits, f64_value);
// same checks for negative numbers
let u64_bits = (-f64_value).to_bits();
assert_eq!(get_rounding_direction(-f64_value), Direction::Equal, "at {}, {}", u32_bits, -f64_value);
assert_eq!(get_rounding_direction(f64::from_bits(u64_bits - 1)), Direction::Down, "at {}, {}", u32_bits, f64_value);
assert_eq!(get_rounding_direction(f64::from_bits(u64_bits + 1)), Direction::Up, "at {}, {}", u32_bits, f64_value);
}
}
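To specifically cast with rounding towards positive infinity (for a negative f32, stepping toward +infinity decrements the bit pattern rather than incrementing it):
fn cast_toward_inf(vf64: f64) -> f32 {
let vf32 = vf64 as f32;
if vf64 > vf32 as f64 {
// vf32 was rounded down: step one ulp toward +infinity
return f32::from_bits(if vf32.is_sign_negative() { vf32.to_bits() - 1 } else { vf32.to_bits() + 1 });
}
vf32
}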
It's possible to determine the rounding primarily from bit 28 (counting from 0, the most significant bit of the truncated portion of the mantissa), but handling the edge cases introduces significant complexity.
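A sketch of that bit-28 rule for the easy case only. It assumes the f32 result is a finite normal number, so subnormals, overflow to infinity and NaN (the edge cases mentioned) are left out, and it relies on the cast rounding to nearest with ties to even, which is what Rust's as does for f64 to f32. magnitude_rounded_up is a hypothetical name; for negative inputs, rounding the magnitude up means the value rounds down.

```rust
/// For an f64 whose f32 cast is a finite normal number, report whether the
/// cast rounds the magnitude up (Some(true)), down (Some(false)), or is exact (None).
/// Subnormal, overflowing, infinite and NaN inputs are NOT handled here.
fn magnitude_rounded_up(v: f64) -> Option<bool> {
    let bits = v.to_bits();
    // The 29 low mantissa bits are the ones an f32 cannot keep (52 - 23 = 29).
    let truncated = bits & ((1u64 << 29) - 1);
    let half = 1u64 << 28; // bit 28: the first truncated bit
    if truncated == 0 {
        None
    } else if truncated != half {
        // Bit 28 set with more bits below rounds up; bit 28 clear rounds down.
        Some(truncated > half)
    } else {
        // Exact tie: round half to even, i.e. up only if the last kept bit (bit 29) is 1.
        Some((bits >> 29) & 1 == 1)
    }
}

fn main() {
    for v in [1.0, 0.1, 1.0 + 2f64.powi(-24), -7.7] {
        println!("{} -> {:?}", v, magnitude_rounded_up(v));
    }
}
```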

Literal out of range warning when iterating over all values of u8 [duplicate]

This question already has answers here:
How to iterate over all byte values (overflowing_literals in `0..256`)
(2 answers)
Closed 5 years ago.
The range in a for-loop, as I understand, is lower-limit inclusive and upper-limit exclusive. This is introducing an issue in the following code:
fn main() {
let a: u8 = 4;
for num in 0..256 {
if num == a {
println!("match found");
break;
}
}
}
I want to loop 256 times from 0 to 255, and this fits into the range of data supported by u8. But since the range is upper limit exclusive, I have to give 256 as the limit to process 255. Due to this, the compiler gives the following warning.
warning: literal out of range for u8
--> src/main.rs:4:19
|
4 | for num in 0..256 {
| ^^^
|
= note: #[warn(overflowing_literals)] on by default
When I execute this, the program skips the for loop.
In my opinion, the compiler should ignore 256 in the range and accept it as a u8 range. Is that correct? Is there any other way to give the range?
You can combine iterators to create full u8 range:
use std::iter::once;
for i in (0..255).chain(once(255)){
//...
}
Inclusive ranges have been stable since Rust 1.26, with the syntax ..= instead of the original nightly ...:
for i in 0..=255 {
//...
}
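A minimal sketch of the question's search with ..=; using u8::MAX as the upper bound keeps the whole range in u8, and 255 itself is visited. find_byte is a hypothetical helper:

```rust
/// Search every u8 value, 0 through 255 inclusive.
fn find_byte(a: u8) -> Option<u8> {
    for num in 0..=u8::MAX {
        if num == a {
            return Some(num);
        }
    }
    None
}

fn main() {
    // 255 is reached, unlike with 0..255
    assert_eq!(find_byte(255), Some(255));
    assert_eq!((0..=u8::MAX).count(), 256);
    if find_byte(4).is_some() {
        println!("match found");
    }
}
```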
As alex already said, it's probably best to iterate using bigger integer types (like u32) and cast when comparing to the u8 you're searching for:
let a: u8 = 4;
for num in 0..256 {
if (num as u8) == a {
println!("match found");
break;
}
}
In this special case you can also use an unbounded range: num is inferred as u8, and the loop breaks at 4 before it would overflow past 255:
let a: u8 = 4;
for num in 0.. {
if num == a {
println!("match found");
break;
}
}
Also when I execute this, the program skips the for loop.
The binary representation of 256 is 1_0000_0000. A u8 only keeps the 8 rightmost bits, so just 0s. Thus 0..256u8 is equivalent to 0..0, which of course is an empty range.
I think you are comparing different types, so you need to cast. Try this:
for num in 0..256 {
let y = num as u8;
if y == a {
println!("found");
break;
}
}
