Rustacean way to parse a date from &str [Rust] [closed]

I have this function, but I feel that it's duplicate code. I was wondering if anyone here could share how to make it more Rustacean. I'm still learning Rust, and I thought this could be a good example to share.
use chrono::prelude::*;
use std::process;
use std::str::FromStr;

fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let start_date_parsed: i64 = Utc
.ymd(
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[0]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[1]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[2]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
let end_date_parsed: i64 = Utc
.ymd(
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[0]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[1]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[2]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
)
.and_hms_milli(0, 0, 2, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
(start_date_parsed, end_date_parsed)
}
Mainly I want to remove the duplication in the three arguments passed to Utc.ymd, since they do the same thing with a different index: they parse a date such as "2021-01-01", return it in milliseconds, and clamp it to a floor and a ceiling.

Can we call "Don't reinvent the wheel" a Rustacean way? Probably..
Playground
You can use NaiveDate::parse_from_str, with a lot of formatting options.
You can then use Utc::from_utc_date or Utc::from_local_date to obtain the Date and process it later like you did
This reduces the code to:
use chrono::*;
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let start_date = NaiveDate::parse_from_str(start_date, "%Y-%m-%d")
.expect(&message);
let start_date_parsed: i64 = Utc.from_utc_date(&start_date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
let end_date = NaiveDate::parse_from_str(end_date, "%Y-%m-%d")
.expect(&message);
let end_date_parsed: i64 = Utc.from_utc_date(&end_date)
.and_hms_milli(0, 0, 2, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
(start_date_parsed, end_date_parsed)
}
That's where the question of code design comes in, and it's opinion-based. You see, you repeat the same code twice for processing the start and end dates; ideally this code should become a separate function.
use chrono::*;
fn parse_and_transform_date(date_str: &str) -> Result<i64, format::ParseError> {
let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")?;
let date_parsed: i64 = Utc.from_utc_date(&date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
Ok(date_parsed)
}
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let start_date_result = parse_and_transform_date(start_date);
let end_date_result = parse_and_transform_date(end_date);
if let (Ok(start_date), Ok(end_date)) = (start_date_result, end_date_result) {
return (start_date, end_date); // success
}
panic!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
}
Finally, I would also ask you to consider a few problems:
It's not good taste to abort your whole program with an error message like that. I would suggest that check_and_transform_dates instead return a Result, and then, taking that result into account, your calling code should handle the situation properly.
check_and_transform_dates should also probably do additional checks, e.g. check that end_date is not before start_date, etc.
Example of fixing:
use chrono::*;
fn parse_and_transform_date(date_str: &str) -> Result<i64, format::ParseError> {
let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")?;
let date_parsed: i64 = Utc.from_utc_date(&date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
Ok(date_parsed)
}
fn check_and_transform_dates(start_date: &str, end_date: &str) -> Option<(i64, i64)> {
let start_date_result = parse_and_transform_date(start_date);
let end_date_result = parse_and_transform_date(end_date);
match (start_date_result, end_date_result) {
(Ok(s), Ok(e)) if (s <= e) => Some((s, e)),
_ => None
}
}
fn main() {
println!("{:?}", check_and_transform_dates("2020-01-02", "2020-01-03")
.expect("Data could not be downloaded"));
}

I'm going with this refactoring
use chrono::prelude::*;
use std::process;

fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let earliest: NaiveDate = NaiveDate::from_ymd(2016, 1, 1);
let today: NaiveDate = Utc::today().naive_utc();
let parse_date = |date: &str| -> NaiveDate {
let date: NaiveDate = NaiveDate::parse_from_str(date, "%F").unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
});
if date < earliest {
earliest
} else if date > today {
today
} else {
date
}
};
(
parse_date(start_date).and_hms(0, 0, 1).timestamp() * 1000,
parse_date(end_date).and_hms(0, 0, 2).timestamp() * 1000,
)
}

Make a function for checking and clamping.
A single fn check_and_clamp_date() can serve both start_date and end_date:
extern crate chrono;
use chrono::{Duration, NaiveDate, Utc};
static FMT: &str = "%Y-%m-%d";
fn check_and_clamp_date(date: &str, floor_date: &str, ceil_date: &str) -> Result<i64, String> {
let floor = NaiveDate::parse_from_str(floor_date, FMT).expect("floor date corrupt");
let ceil = NaiveDate::parse_from_str(ceil_date, FMT).expect("ceil date corrupt");
let r = NaiveDate::parse_from_str(date, FMT);
match r {
Err(_) => Err("invalid date format".to_string()),
Ok(d) => {
let delta: Duration = d.clamp(floor, ceil).signed_duration_since(NaiveDate::from_ymd(1970, 1, 1));
Ok(delta.num_nanoseconds().unwrap() / 1000000)
},
}
}
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let now = Utc::now().format(FMT).to_string();
let start = check_and_clamp_date(start_date, "2016-01-01", &now).expect(&message) + 1000;
let end = check_and_clamp_date(end_date, "2016-01-01", &now).expect(&message) + 2000;
(start, end)
}
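A minimal usage sketch (assuming the two functions above are in scope; the printed values are the clamped millisecond timestamps, including the +1 s / +2 s offsets):
fn main() {
    // Both dates fall inside the 2016-01-01..now window, so they pass through unclamped.
    let (start_ms, end_ms) = check_and_transform_dates("2020-01-01", "2020-01-02");
    println!("start: {} ms, end: {} ms", start_ms, end_ms);
}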

Related

Text to Number Conversion

I'm trying to create a Rust version of the accepted solution to this question, which is to convert a string such as "two hundred fifty eight" into 258.
I created the code below to match the pseudocode given, but somehow it's unable to get the hundreds and thousands right. My example string returns 158 instead of 258, for instance. What am I missing?
let _units = vec![ "zero", "one", "two", "three", "four", "five", "six", "seven",
"eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
"sixteen", "seventeen", "eighteen", "nineteen", ];
let mut numbers: Vec<(&str,u32)> = _units.iter().enumerate().map(|(idx, n)| (*n, idx as u32)).collect();
let _tens = vec!["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"];
let mut tens: Vec<(&str,u32)> = _tens.iter().enumerate().map(|(idx, n)| (*n, (idx * 10) as u32)).collect();
let _scales = vec!["hundred", "thousand", "million", "billion"];
let base:i32 = 1000;
let mut scales: Vec<(&str,u32)> = _scales.iter().enumerate().map(|(idx, n)| (*n, base.pow(idx as u32) as u32)).collect();
numbers.append(&mut tens);
numbers.append(&mut scales);
use std::collections::HashMap;
fn text_to_int(textnum: &str, numwords: HashMap<&str, u32>) -> u32 {
let mut result = 0;
let mut prior = 0;
for word in textnum.split(" ") {
let value = numwords[word];
if prior == 0 {
prior = value;
} else if prior > value {
prior += value;
} else {
prior *= value;
};
if value > 100 && prior != 0 {
result += prior;
prior = 0;
};
}
return result + prior;
}
let numwords: HashMap<_, _> = numbers.into_iter().collect();
let textnum = "two hundred fifty eight thousand";
println!("{:?}", text_to_int(textnum, numwords));
Returns 158000. What am I doing wrong?
Your problem is with these lines:
let base:i32 = 1000;
let mut scales: Vec<(&str,u32)> = _scales.iter().enumerate().map(|(idx, n)| (*n, base.pow(idx as u32) as u32)).collect();
because "hundred" give you 1 not 100
with your algorithm you can't produce correct numbers for scales
"hundred" -> 100 -> 10^2
"thousand" -> 1_000 -> 10^3
"million" -> 1_000_000 -> 10^6
"billion" -> 1_000_000_000 -> 10^9
A simple way to fix this is to add "hundred" manually:
let _scales = vec!["thousand", "million", "billion"];
let base:i32 = 1000;
let mut scales: Vec<(&str,u32)> = _scales.iter().enumerate().map(|(idx, n)| (*n, base.pow(idx as u32 + 1) as u32)).collect();
numbers.append(&mut tens);
numbers.append(&mut scales);
numbers.push(("hundred", 100));
[Edit]
An alternative way to solve this is to implement FromStr, which will give you the advantages of the parse() function.
To do this you can use the following code.
First, create a custom type for holding a value such as u32:
#[derive(Debug)]
struct Number<T>(T);
It's good if you impl Display too (optional):
impl<T: Display> Display for Number<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
Now create the text-to-number map:
const BASIC: [(&'static str, u32); 32] = [
("zero", 0),
("one", 1),
("two", 2),
("three", 3),
("four", 4),
("five", 5),
("six", 6),
("seven", 7),
("eight", 8),
("nine", 9),
("ten", 10),
("eleven", 11),
("twelve", 12),
("thirteen", 13),
("fourteen", 14),
("fifteen", 15),
("sixteen", 16),
("seventeen", 17),
("eighteen", 18),
("nineteen", 19),
("twenty", 20),
("thirty", 30),
("forty", 40),
("fifty", 50),
("sixty", 60),
("seventy", 70),
("eighty", 80),
("ninety", 90),
("hundred", 100),
("thousand", 1000),
("million", 1000000),
("billion", 1000000000),
];
Now just place your text_to_int function code in the from_str function, like this:
impl FromStr for Number<u32> {
type Err = ();
fn from_str(textnum: &str) -> Result<Self, Self::Err> {
let textnum = textnum.to_lowercase();
let numwords: HashMap<_, _> = BASIC.into_iter().collect();
let mut result = 0u32;
let mut prior = 0u32;
for word in textnum.split(" ") {
let value = numwords[word].into();
if prior == 0 {
prior = value;
}
else if prior > value {
prior += value;
}
else {
prior *= value;
}
if value > 100 && prior != 0 {
result += prior;
prior = 0;
}
}
Ok(Number(result + prior))
}
}
Now you can parse any such &str like this:
fn main() {
let num: Number<u32> = "two hundred fifty eight thousand".parse().unwrap();
println!("{}", num);
}

traversing a grid up or down without having to write the same loop twice

I'm working on moving in a vector that acts as a representation of a 2d grid, and I want to be able to move straight up or down, but don't want to rewrite the for loop twice. I came up with this solution that fails due to the borrow checker when indexing into the grid.
// fill the vec later
let grid: Vec<u8> = Vec::with_capacity(WIDTH*HEIGHT);
let mut current_position = Point::new(somewhere);
let end = Point::new(somewhere_else);
let moving: &mut usize;
let to: u32;
if moving_horizontal {
moving = &mut current_position.x;
to = end.x;
}
else {
moving = &mut current_position.y;
to = end.y;
}
for _ in *moving..=to {
// do stuff with this
grid[current_position.x+current_position.y*HEIGHT];
*moving += 1;
}
Is there any neat solution to this or do I have to just write the same for loop twice in each block of my conditional statements?
Please don't be too harsh on me. I just started learning Rust a few days ago. This is just an idea I had, and I'm posting it to learn whether it works in this situation or not.
let mut current_position = Point::new(somewhere);
let end = Point::new(somewhere_else);
let moving: &mut usize;
let not_moving: &usize;
let to: u32;
if moving_horizontal {
moving = &mut current_position.x;
not_moving = &current_position.y;
to = end.x
} else {
moving = &mut current_position.y;
not_moving = &current_position.x;
to = end.y
}
for _ in *moving..=to {
// do stuff with this
if moving_horizontal {
grid[*moving + *not_moving * HEIGHT];
} else {
grid[*not_moving + *moving * HEIGHT];
}
*moving += 1;
}
If we didn't have to use moving in the for declaration, it could be written in a much better way, but I'm not sure about the details of the algorithm.
Anyway, even if this passes the borrow checker, it is very difficult to read and a double for loop would be much more maintainable.
I think you can generate all the move steps first and then change the current point with the steps one by one in one single loop.
There are four kinds of steps -- (-1, 0), (1, 0), (0, 1) and (0, -1), corresponding to moving left, right, up and down. You can calculate how many steps are needed and which direction needs to be taken first based on the relative positions of the curr and end points and the moving_horizontal flag.
use std::iter::repeat;
struct Point {
x: usize,
y: usize,
}
fn main() {
const HEIGHT: usize = 25;
const WIDTH: usize = 80;
const LEFT: (i8, i8) = (-1, 0);
const RIGHT: (i8, i8) = (1, 0);
const UP: (i8, i8) = (0, 1);
const DOWN: (i8, i8) = (0, -1);
let grid: Vec<u8> = vec![0; WIDTH * HEIGHT];
let mut curr = Point {x: 6, y: 2 };
let end = Point {x: 3, y: 4 };
let (x_step, x_num) = if end.x > curr.x { (RIGHT, end.x - curr.x) }
else { (LEFT, curr.x - end.x) };
let (y_step, y_num) = if end.y > curr.y { (UP, end.y - curr.y) }
else { (DOWN, curr.y - end.y) };
let moving_horizontal = true;
let all_steps = if moving_horizontal {
repeat(x_step).take(x_num).chain(repeat(y_step).take(y_num))
} else {
repeat(y_step).take(y_num).chain(repeat(x_step).take(x_num))
};
println!("Start: (x:{}, y:{})", curr.x, curr.y);
println!("End: (x:{}, y:{})", end.x, end.y);
println!("Steps (Horizontal: {}):", moving_horizontal);
for (i, step) in all_steps.enumerate() {
// do stuff with this
grid[curr.x + curr.y * HEIGHT];
curr.x = match step.0 {
1 => curr.x + 1,
-1 => curr.x - 1,
_ => curr.x
};
curr.y = match step.1 {
1 => curr.y + 1,
-1 => curr.y -1,
_ => curr.y
};
println!("{} => (x:{}, y:{})", i, curr.x, curr.y);
}
}
Playground
The output when moving_horizontal is true:
Start: (x:6, y:2)
End: (x:3, y:4)
Steps (Horizontal: true):
0 => (x:5, y:2)
1 => (x:4, y:2)
2 => (x:3, y:2)
3 => (x:3, y:3)
4 => (x:3, y:4)
The output when moving_horizontal is false:
Start: (x:6, y:2)
End: (x:3, y:4)
Steps (Horizontal: false):
0 => (x:6, y:3)
1 => (x:6, y:4)
2 => (x:5, y:4)
3 => (x:4, y:4)
4 => (x:3, y:4)
If I understood correctly, you want to mutate the current_position so that it moves to the end Point horizontally or vertically.
fn main() {
moving_position(true);
moving_position(false);
}
#[derive(Debug)]
struct Point {
x: usize,
y: usize,
}
impl Point {
fn new(x: usize, y: usize) -> Self {
Self { x, y }
}
}
const WIDTH: usize = 30;
const HEIGHT: usize = 30;
fn moving_position(moving_horizontal: bool) {
let grid: Vec<u8> = Vec::with_capacity(WIDTH * HEIGHT);
let mut current_position = Point::new(8, 5);
let end = Point::new(0, 1);
println!("Starting at: {:?}", current_position);
match moving_horizontal {
true => move_horizontal(&mut current_position, end.x),
false => move_vertical(&mut current_position, end.y),
};
println!("Final position: {:?}",current_position);
}
fn move_horizontal(current_position: &mut Point, to: usize) {
match current_position.x <= to {
true => for i in current_position.x+1..=to {
current_position.x = i;
println!("moving right: {:?}", current_position);
},
false => for i in (to..current_position.x).rev() {
current_position.x = i;
println!("moving left: {:?}", current_position);
},
};
}
fn move_vertical(mut current_position: &mut Point, to: usize) {
match current_position.y <= to {
true => for i in current_position.y+1..=to {
current_position.y = i;
println!("moving up: {:?}", current_position);
},
false => for i in (to..current_position.y).rev() {
current_position.y = i;
println!("moving down: {:?}", current_position);
},
};
}
We check if its moving_horizontal and depending on that we call functions move_horizontal or move_vertical.
These two functions take the current_position as a mutable reference (meaning we will mutate current_position in the function but will return ownership).
Playground

Severe performance degradation over time in multi-threading: what am I missing?

In my application a method runs quickly once started but begins to continuously degrade in performance as it nears completion; this seems to happen regardless of the amount of work (the number of iterations of a function each thread has to perform). Near the end it slows to an incredibly slow pace compared to earlier (worth noting this is not just a result of fewer threads remaining incomplete; each individual thread also seems to slow down).
I cannot figure out why this occurs, so I'm asking. What am I doing wrong?
An overview of CPU usage (a sequence of screenshots, not reproduced here; worth noting that CPU temperature remains low throughout):
At first all threads sit constantly near 100%. How long this stage lasts varies with however much work is set (more work keeps it looking this way for longer), and at this point everything appears good.
That performance continues for a while, then it starts to degrade. I do not know why this occurs.
After some period of chaos most of the threads have finished their work and the remaining threads continue; at this point, although they appear to be at 100%, they actually work through their remaining workload very slowly. I cannot understand why this occurs.
Printing progress
I have written a multi-threaded random_search (documentation link) function for optimization. Most of the complexity in this function comes from printing data and passing data between threads; this supports giving outputs showing progress like:
2300
565 (24.57%) 00:00:11 / 00:00:47 [25.600657363049734] { [563.0ns, 561.3ms, 125.0ns, 110.0ns] [2.0µs, 361.8ms, 374.0ns, 405.0ns] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] }
I have been trying to use this output to figure out what's gone wrong, but I have no idea.
This output describes:
The total number of iterations 2300.
The total number of current iterations 565.
The time running 00:00:11 (mm:ss:ms).
The estimated time remaining 00:00:47 (mm:ss:ms).
The current best value [25.600657363049734].
The most recently measured times between execution positions (effectively the time taken for a thread to go from one line to another, with the positions set via update_execution_position in the code below) [563.0ns, 561.3ms, 125.0ns, 110.0ns].
The average times between execution positions (averaged across the entire runtime rather than since last measured) [2.0µs, 361.8ms, 374.0ns, 405.0ns].
The execution positions of the threads (0 means a thread has completed; the other values mean a thread has passed the line that sets that position but has not yet reached the next line that changes it, so it is effectively between two positions) [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1].
The random_search code:
Given that I have tested implementations with the other methods in my library (grid_search and simulated_annealing), it would suggest to me that the problem does not, at least not entirely, reside in random_search.rs.
random_search.rs:
pub fn random_search<
A: 'static + Send + Sync,
T: 'static + Copy + Send + Sync + Default + SampleUniform + PartialOrd,
const N: usize,
>(
// Generics
ranges: [Range<T>; N],
f: fn(&[T; N], Option<Arc<A>>) -> f64,
evaluation_data: Option<Arc<A>>,
polling: Option<Polling>,
// Specifics
iterations: u64,
) -> [T; N] {
// Gets cpu data
let cpus = num_cpus::get() as u64;
let search_cpus = cpus - 1; // 1 cpu is used for polling, this one.
let remainder = iterations % search_cpus;
let per = iterations / search_cpus;
let ranges_arc = Arc::new(ranges);
let (best_value, best_params) = search(
// Generics
ranges_arc.clone(),
f,
evaluation_data.clone(),
// Since we are doing this on the same thread, we don't need to use these
Arc::new(AtomicU64::new(Default::default())),
Arc::new(Mutex::new(Default::default())),
Arc::new(AtomicBool::new(false)),
Arc::new(AtomicU8::new(0)),
Arc::new([
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
]),
// Specifics
remainder,
);
let thread_exit = Arc::new(AtomicBool::new(false));
// (handles,(counters,thread_bests))
let (handles, links): (Vec<_>, Vec<_>) = (0..search_cpus)
.map(|_| {
let ranges_clone = ranges_arc.clone();
let counter = Arc::new(AtomicU64::new(0));
let thread_best = Arc::new(Mutex::new(f64::MAX));
let thread_execution_position = Arc::new(AtomicU8::new(0));
let thread_execution_time = Arc::new([
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
Mutex::new((Duration::new(0, 0), 0)),
]);
let counter_clone = counter.clone();
let thread_best_clone = thread_best.clone();
let thread_exit_clone = thread_exit.clone();
let evaluation_data_clone = evaluation_data.clone();
let thread_execution_position_clone = thread_execution_position.clone();
let thread_execution_time_clone = thread_execution_time.clone();
(
thread::spawn(move || {
search(
// Generics
ranges_clone,
f,
evaluation_data_clone,
counter_clone,
thread_best_clone,
thread_exit_clone,
thread_execution_position_clone,
thread_execution_time_clone,
// Specifics
per,
)
}),
(
counter,
(
thread_best,
(thread_execution_position, thread_execution_time),
),
),
)
})
.unzip();
let (counters, links): (Vec<Arc<AtomicU64>>, Vec<_>) = links.into_iter().unzip();
let (thread_bests, links): (Vec<Arc<Mutex<f64>>>, Vec<_>) = links.into_iter().unzip();
let (thread_execution_positions, thread_execution_times) = links.into_iter().unzip();
if let Some(poll_data) = polling {
poll(
poll_data,
counters,
remainder,
iterations,
thread_bests,
thread_exit,
thread_execution_positions,
thread_execution_times,
);
}
let joins: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
let (_, best_params) = joins
.into_iter()
.fold((best_value, best_params), |(bv, bp), (v, p)| {
if v < bv {
(v, p)
} else {
(bv, bp)
}
});
return best_params;
fn search<
A: 'static + Send + Sync,
T: 'static + Copy + Send + Sync + Default + SampleUniform + PartialOrd,
const N: usize,
>(
// Generics
ranges: Arc<[Range<T>; N]>,
f: fn(&[T; N], Option<Arc<A>>) -> f64,
evaluation_data: Option<Arc<A>>,
counter: Arc<AtomicU64>,
best: Arc<Mutex<f64>>,
thread_exit: Arc<AtomicBool>,
thread_execution_position: Arc<AtomicU8>,
thread_execution_times: Arc<[Mutex<(Duration, u64)>; 4]>,
// Specifics
iterations: u64,
) -> (f64, [T; N]) {
let mut execution_position_timer = Instant::now();
let mut rng = thread_rng();
let mut params = [Default::default(); N];
let mut best_value = f64::MAX;
let mut best_params = [Default::default(); N];
for _ in 0..iterations {
// Gen random values
for (range, param) in ranges.iter().zip(params.iter_mut()) {
*param = rng.gen_range(range.clone());
}
// Update execution position
execution_position_timer = update_execution_position(
1,
execution_position_timer,
&thread_execution_position,
&thread_execution_times,
);
// Run function
let new_value = f(&params, evaluation_data.clone());
// Update execution position
execution_position_timer = update_execution_position(
2,
execution_position_timer,
&thread_execution_position,
&thread_execution_times,
);
// Check best
if new_value < best_value {
best_value = new_value;
best_params = params;
*best.lock().unwrap() = best_value;
}
// Update execution position
execution_position_timer = update_execution_position(
3,
execution_position_timer,
&thread_execution_position,
&thread_execution_times,
);
counter.fetch_add(1, Ordering::SeqCst);
// Update execution position
execution_position_timer = update_execution_position(
4,
execution_position_timer,
&thread_execution_position,
&thread_execution_times,
);
if thread_exit.load(Ordering::SeqCst) {
break;
}
}
// Update execution position
// 0 represents ended state
thread_execution_position.store(0, Ordering::SeqCst);
return (best_value, best_params);
}
}
util.rs:
pub fn update_execution_position<const N: usize>(
i: usize,
execution_position_timer: Instant,
thread_execution_position: &Arc<AtomicU8>,
thread_execution_times: &Arc<[Mutex<(Duration, u64)>; N]>,
) -> Instant {
{
let mut data = thread_execution_times[i - 1].lock().unwrap();
data.0 += execution_position_timer.elapsed();
data.1 += 1;
}
thread_execution_position.store(i as u8, Ordering::SeqCst);
Instant::now()
}
pub struct Polling {
pub poll_rate: u64,
pub printing: bool,
pub early_exit_minimum: Option<f64>,
pub thread_execution_reporting: bool,
}
impl Polling {
const DEFAULT_POLL_RATE: u64 = 10;
pub fn new(printing: bool, early_exit_minimum: Option<f64>) -> Self {
Self {
poll_rate: Polling::DEFAULT_POLL_RATE,
printing,
early_exit_minimum,
thread_execution_reporting: false,
}
}
}
pub fn poll<const N: usize>(
data: Polling,
// Current count of each thread.
counters: Vec<Arc<AtomicU64>>,
offset: u64,
// Final total iterations.
iterations: u64,
// Best values of each thread.
thread_bests: Vec<Arc<Mutex<f64>>>,
// Early exit switch.
thread_exit: Arc<AtomicBool>,
// Current positions of execution of each thread.
thread_execution_positions: Vec<Arc<AtomicU8>>,
// Current average times between execution positions for each thread
thread_execution_times: Vec<Arc<[Mutex<(Duration, u64)>; N]>>,
) {
let start = Instant::now();
let mut stdout = stdout();
let mut count = offset
+ counters
.iter()
.map(|c| c.load(Ordering::SeqCst))
.sum::<u64>();
if data.printing {
println!("{:20}", iterations);
}
let mut poll_time = Instant::now();
let mut held_best: f64 = f64::MAX;
let mut held_average_execution_times: [(Duration, u64); N] =
vec![(Duration::new(0, 0), 0); N].try_into().unwrap();
let mut held_recent_execution_times: [Duration; N] =
vec![Duration::new(0, 0); N].try_into().unwrap();
while count < iterations {
if data.printing {
// loop {
let percent = count as f32 / iterations as f32;
// If count == 0, give 00... for remaining time as placeholder
let remaining_time_estimate = if count == 0 {
Duration::new(0, 0)
} else {
start.elapsed().div_f32(percent)
};
print!(
"\r{:20} ({:.2}%) {} / {} [{}] {}\t",
count,
100. * percent,
print_duration(start.elapsed(), 0..3),
print_duration(remaining_time_estimate, 0..3),
if held_best == f64::MAX {
String::from("?")
} else {
format!("{}", held_best)
},
if data.thread_execution_reporting {
let (average_execution_times, recent_execution_times): (
Vec<String>,
Vec<String>,
) = (0..thread_execution_times[0].len())
.map(|i| {
let (mut sum, mut num) = (Duration::new(0, 0), 0);
for n in 0..thread_execution_times.len() {
{
let mut data = thread_execution_times[n][i].lock().unwrap();
sum += data.0;
held_average_execution_times[i].0 += data.0;
num += data.1;
held_average_execution_times[i].1 += data.1;
*data = (Duration::new(0, 0), 0);
}
}
if num > 0 {
held_recent_execution_times[i] = sum.div_f64(num as f64);
}
(
if held_average_execution_times[i].1 > 0 {
format!(
"{:.1?}",
held_average_execution_times[i]
.0
.div_f64(held_average_execution_times[i].1 as f64)
)
} else {
String::from("?")
},
if held_recent_execution_times[i] > Duration::new(0, 0) {
format!("{:.1?}", held_recent_execution_times[i])
} else {
String::from("?")
},
)
})
.unzip();
let execution_positions: Vec<u8> = thread_execution_positions
.iter()
.map(|pos| pos.load(Ordering::SeqCst))
.collect();
format!(
"{{ [{}] [{}] {:.?} }}",
recent_execution_times.join(", "),
average_execution_times.join(", "),
execution_positions
)
} else {
String::from("")
}
);
stdout.flush().unwrap();
}
// Updates best and does early exiting
match (data.early_exit_minimum, data.printing) {
(Some(early_exit), true) => {
for thread_best in thread_bests.iter() {
let thread_best_temp = *thread_best.lock().unwrap();
if thread_best_temp < held_best {
held_best = thread_best_temp;
if thread_best_temp <= early_exit {
thread_exit.store(true, Ordering::SeqCst);
println!();
return;
}
}
}
}
(None, true) => {
for thread_best in thread_bests.iter() {
let thread_best_temp = *thread_best.lock().unwrap();
if thread_best_temp < held_best {
held_best = thread_best_temp;
}
}
}
(Some(early_exit), false) => {
for thread_best in thread_bests.iter() {
if *thread_best.lock().unwrap() <= early_exit {
thread_exit.store(true, Ordering::SeqCst);
return;
}
}
}
(None, false) => {}
}
thread::sleep(saturating_sub(
Duration::from_millis(data.poll_rate),
poll_time.elapsed(),
));
poll_time = Instant::now();
count = offset
+ counters
.iter()
.map(|c| c.load(Ordering::SeqCst))
.sum::<u64>();
}
if data.printing {
println!(
"\r{:20} (100.00%) {} / {} [{}] {}\t",
count,
print_duration(start.elapsed(), 0..3),
print_duration(start.elapsed(), 0..3),
held_best,
if data.thread_execution_reporting {
let (average_execution_times, recent_execution_times): (Vec<String>, Vec<String>) =
(0..thread_execution_times[0].len())
.map(|i| {
let (mut sum, mut num) = (Duration::new(0, 0), 0);
for n in 0..thread_execution_times.len() {
{
let mut data = thread_execution_times[n][i].lock().unwrap();
sum += data.0;
held_average_execution_times[i].0 += data.0;
num += data.1;
held_average_execution_times[i].1 += data.1;
*data = (Duration::new(0, 0), 0);
}
}
if num > 0 {
held_recent_execution_times[i] = sum.div_f64(num as f64);
}
(
if held_average_execution_times[i].1 > 0 {
format!(
"{:.1?}",
held_average_execution_times[i]
.0
.div_f64(held_average_execution_times[i].1 as f64)
)
} else {
String::from("?")
},
if held_recent_execution_times[i] > Duration::new(0, 0) {
format!("{:.1?}", held_recent_execution_times[i])
} else {
String::from("?")
},
)
})
.unzip();
let execution_positions: Vec<u8> = thread_execution_positions
.iter()
.map(|pos| pos.load(Ordering::SeqCst))
.collect();
format!(
"{{ [{}] [{}] {:.?} }}",
recent_execution_times.join(", "),
average_execution_times.join(", "),
execution_positions
)
} else {
String::from("")
}
);
stdout.flush().unwrap();
}
}
// Since `Duration::saturating_sub` is unstable this is an alternative.
fn saturating_sub(a: Duration, b: Duration) -> Duration {
if let Some(dur) = a.checked_sub(b) {
dur
} else {
Duration::new(0, 0)
}
}
main.rs
use std::{cmp,sync::Arc};
type Image = Vec<Vec<Pixel>>;
#[derive(Clone)]
pub struct Pixel {
pub luma: u8,
}
impl From<&u8> for Pixel {
fn from(x: &u8) -> Pixel {
Pixel { luma: *x }
}
}
fn main() {
// Setup
// -------------------------------------------
fn open_image(path: &str) -> Image {
let example = image::open(path).unwrap().to_rgb8();
let dims = example.dimensions();
let size = (dims.0 as usize, dims.1 as usize);
let example_vec = example.into_raw();
// Binarizes image
let img_vec = from_raw(&example_vec, size);
img_vec
}
println!("Started ...");
let example: Image = open_image("example.jpg");
let target: Image = open_image("target.jpg");
// let first_image = Some(Arc::new((examples[0].clone(), targets[0].clone())));
println!("Opened...");
let image = Some(Arc::new((example, target)));
// Running the optimization
// -------------------------------------------
println!("Started opt...");
let best = simple_optimization::random_search(
[0..255, 0..255, 0..255, 1..255, 1..255],
eval_one,
image,
Some(simple_optimization::Polling {
poll_rate: 100,
printing: true,
early_exit_minimum: None,
thread_execution_reporting: true,
}),
2300,
);
println!("{:.?}", best); // [34, 220, 43, 253, 168]
assert!(false);
fn eval_one(arr: &[u8; 5], opt: Option<Arc<(Image, Image)>>) -> f64 {
let bin_params = (
arr[0] as u8,
arr[1] as u8,
arr[2] as u8,
arr[3] as usize,
arr[4] as usize,
);
let arc = opt.unwrap();
// Gets average mean-squared-error
let binary_pixels = binarize_buffer(arc.0.clone(), bin_params);
mse(binary_pixels, &arc.1)
}
// Mean-squared-error
fn mse(prediction: Image, target: &Image) -> f64 {
let n = target.len() * target[0].len();
prediction
.iter()
.flatten()
.zip(target.iter().flatten())
.map(|(p, t)| difference(p, t).powf(2.))
.sum::<f64>()
/ (2. * n as f64)
}
#[rustfmt::skip]
fn difference(p: &Pixel, t: &Pixel) -> f64 {
p.luma as f64 - t.luma as f64
}
}
pub fn from_raw(raw: &[u8], (_i_size, j_size): (usize, usize)) -> Vec<Vec<Pixel>> {
(0..raw.len())
.step_by(j_size)
.map(|index| {
raw[index..index + j_size]
.iter()
.map(Pixel::from)
.collect::<Vec<Pixel>>()
})
.collect()
}
pub fn binarize_buffer(
mut img: Vec<Vec<Pixel>>,
(_, _, local_luma_boundary, local_field_reach, local_field_size): (u8, u8, u8, usize, usize),
) -> Vec<Vec<Pixel>> {
let (i_size, j_size) = (img.len(), img[0].len());
let i_chunks = (i_size as f32 / local_field_size as f32).ceil() as usize;
let j_chunks = (j_size as f32 / local_field_size as f32).ceil() as usize;
let mut local_luma: Vec<Vec<u8>> = vec![vec![u8::default(); j_chunks]; i_chunks];
// Gets average luma in local fields
// O((s+r)^2*(n/s)*(m/s)) : s = local field size, r = local field reach
for (i_chunk, i) in (0..i_size).step_by(local_field_size).enumerate() {
let i_range = zero_checked_sub(i, local_field_reach)
..cmp::min(i + local_field_size + local_field_reach, i_size);
let i_range_length = i_range.end - i_range.start;
for (j_chunk, j) in (0..j_size).step_by(local_field_size).enumerate() {
let j_range = zero_checked_sub(j, local_field_reach)
..cmp::min(j + local_field_size + local_field_reach, j_size);
let j_range_length = j_range.end - j_range.start;
let total: u32 = i_range
.clone()
.map(|i_range_indx| {
img[i_range_indx][j_range.clone()]
.iter()
.map(|p| p.luma as u32)
.sum::<u32>()
})
.sum();
local_luma[i_chunk][j_chunk] = (total / (i_range_length * j_range_length) as u32) as u8;
}
}
// Apply binarization
// O(nm)
for i in 0..i_size {
let i_group: usize = i / local_field_size; // == floor(i as f32 / local_field_size as f32) as usize
for j in 0..j_size {
let j_group: usize = j / local_field_size;
// Local average boundaries
// --------------------------------
if let Some(local) = local_luma[i_group][j_group].checked_sub(local_luma_boundary) {
if img[i][j].luma < local {
img[i][j].luma = 0;
continue;
}
}
if let Some(local) = local_luma[i_group][j_group].checked_add(local_luma_boundary) {
if img[i][j].luma > local {
img[i][j].luma = 255;
continue;
}
}
// White is the negative (false/0) colour in our binarization, thus this is our else case
img[i][j].luma = 255;
}
}
img
}
#[rustfmt::skip]
fn zero_checked_sub(a: usize, b: usize) -> usize { if a > b { a - b } else { 0 } }
Project zip (in case you'd rather not spend time setting it up).
Else, here are the images being used as /target.jpg and /example.jpg (not reproduced here; it shouldn't matter that it is specifically these images, any should work).
And Cargo.toml dependencies:
[dependencies]
rand = "0.8.4"
itertools = "0.10.1" # izip!
num_cpus = "1.13.0" # Multi-threading
print_duration = "1.0.0" # Printing progress
num = "0.4.0" # Generics
rand_distr = "0.4.1" # Normal distribution
image = "0.23.14"
serde = { version="1.0.118", features=["derive"] }
serde_json = "1.0.50"
I do feel rather reluctant to post such a large question and
inevitably require people to read a few hundred lines (especially given the project doesn't work in a playground), but I'm really lost here and can see no other way to communicate the whole area of the problem. Apologies for this.
As noted, I have tried for a while to figure out what is happening here, but I have come up short; any help would be really appreciated.
Some basic debugging (aka println! everywhere) shows that your performance problem is not related to the multithreading at all. It just happens randomly, and when there are 24 threads doing their job, the fact that one is randomly stalling is not noticeable, but when there is only one or two threads left, they stand out as slow.
But where is this performance bottleneck? Well, you state it yourself in the code: in binarize_buffer you say:
// Gets average luma in local fields
// O((s+r)^2*(n/s)*(m/s)) : s = local field size, r = local field reach
The values of s and r seem to be random values between 0 and 255, while n is the length of an image row in bytes, 3984 * 3 = 11952, and m is the number of rows, 2271.
Now, most of the time that O() is around a few million, which is quite manageable. But if s happens to be small and r big, such as (3, 200), then the number of computations blows up to over 1e11!
Fortunately, I think you can define the ranges of those values in the original call to random_search, so a bit of tweaking there should send you back to reasonable complexity. Changing the ranges to:
[0..255, 0..255, 0..255, 1..255, 20..255],
// ^ here
seems to do the trick for me.
PS: These lines at the beginning of binarize_buffer were key to discovering this:
let o = (i_size / local_field_size) * (j_size / local_field_size) * (local_field_size + local_field_reach).pow(2);
println!("\nO() = {}", o);

How do I correctly implement std::iter::Step in Rust 1.25 nightly?

I am trying to implement a struct to represent a date, which can be used with range syntax (for d in start..end { }). I know there are already crates out there that handle dates, but I am doing this as an exercise.
Here is the struct:
type DayOfMonth = u8;
type Month = u8;
type Year = u16;
#[derive(PartialEq, Eq, Clone)]
pub struct Date {
pub year: Year,
pub month: Month,
pub day: DayOfMonth
}
Here is how I would like to use it:
fn print_dates() {
let start = Date { year: 1999, month: 1, day: 1 };
let end = Date { year: 1999, month: 12, day: 31 };
for d in start..end {
println!("{}-{}-{}", d.year, d.month, d.day);
}
}
I originally tried implementing the Iterator trait, but then when I tried to use range syntax, I got a compiler error saying I needed to implement Step instead.
The documentation shows this signature for the Step trait.
pub trait Step: PartialOrd<Self> + Clone {
fn steps_between(start: &Self, end: &Self) -> Option<usize>;
fn replace_one(&mut self) -> Self;
fn replace_zero(&mut self) -> Self;
fn add_one(&self) -> Self;
fn sub_one(&self) -> Self;
fn add_usize(&self, n: usize) -> Option<Self>;
}
I've already implemented Ord and PartialOrd:
impl Ord for Date {
fn cmp(&self, other: &Self) -> Ordering {
match self.year.cmp(&other.year) {
Ordering::Equal =>
match self.month.cmp(&other.month) {
Ordering::Equal =>
self.day.cmp(&other.day),
ord => ord
},
ord => ord
}
}
}
impl PartialOrd for Date {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
I am implementing Clone automatically with #[derive(Clone)].
I started implementing Step, but there are some methods that I cannot figure out what to do with. Here is what I have so far:
impl Step for Date {
fn steps_between(start: &Self, end: &Self) -> Option<usize> {
//is_valid_date checks that the month is not > 12 and other rules like that
if is_valid_date(start) && is_valid_date(end) {
//get_epoch_day_number gets the number of days since 1900-01-01
let diff = get_epoch_day_number(end) - get_epoch_day_number(start);
Some(diff)
}
else { None }
}
fn add_one(&self) -> Self {
//Try the next day
let mut next = Date {
year: self.year,
month: self.month,
day: self.day + 1
};
//If not valid, try the 1st of the next month
if !is_valid_date(&next) {
next = Date {
year: self.year,
month: self.month + 1,
day: 1
};
}
//If not valid, try the 1st of the next year
if !is_valid_date(&next) {
next = Date {
year: self.year + 1,
month: 1,
day: 1
};
}
next
}
fn sub_one(&self) -> Self {
//Try the prev day
let mut prev = Date {
year: self.year,
month: self.month,
day: self.day - 1
};
//If not valid, try the last of the prev month
if !is_valid_date(&prev) {
let m = self.month - 1;
prev = Date {
year: self.year,
month: m,
day: get_month_length(self.year, m)
};
}
//If not valid, try the last of the prev year
if !is_valid_date(&prev) {
prev = Date {
year: self.year - 1,
month: 12,
day: 31
};
}
prev
}
fn add_usize(&self, n: usize) -> Option<Self> {
//This is really inefficient, but that's not important
let mut result = self.clone();
for _ in 0..n {
result = result.add_one();
}
Some(result)
}
fn replace_one(&mut self) -> Self {
// ?
}
fn replace_zero(&mut self) -> Self {
// ?
}
}
I am really stumped on what replace_one and replace_zero are supposed to do. The documentation says:
Replaces this step with 1, returning itself. and Replaces this step with 0, returning itself.
Does my struct need to have zero and one identity values just to be used in a range? Shouldn't add_one be enough?
The wording used by the documentation is also a little unclear. If we replace x with 1 and return "itself", is "it" x or 1?
I just looked at Rust's code where those methods are used. The only uses in the whole of rustc's repository are to implement RangeInclusive operations. An empty RangeInclusive is represented as a range from 1 to 0, so the next, next_back and nth methods need to be able to produce those two values somehow, and that's what replace_one and replace_zero are for.
I would suggest opening an issue on rustc's GitHub to make the documentation better, and possibly change the name of these methods.
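For illustration only, here is a minimal self-contained sketch of that idea for a Date like yours, using an arbitrary pair of sentinel dates to play the role of 0 and 1 (the sentinels are my assumption; as far as I can tell the integer impls of that era did essentially mem::replace(self, 0) and mem::replace(self, 1)):
use std::mem;

// Mirror of the question's struct, with the derives this sketch needs.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Date {
    year: u16,
    month: u8,
    day: u8,
}

// Arbitrary sentinel dates standing in for 0 and 1 (an assumption for this sketch).
fn date_zero() -> Date {
    Date { year: 1900, month: 1, day: 1 }
}
fn date_one() -> Date {
    Date { year: 1900, month: 1, day: 2 }
}

// What the bodies of replace_zero / replace_one could look like:
// swap the sentinel into the value and hand back the previous one.
fn replace_zero(d: &mut Date) -> Date {
    mem::replace(d, date_zero())
}
fn replace_one(d: &mut Date) -> Date {
    mem::replace(d, date_one())
}

fn main() {
    let mut d = Date { year: 1999, month: 12, day: 31 };
    let old = replace_one(&mut d);
    // `old` is the original 1999 date, `d` is now the "one" sentinel.
    println!("old = {:?}, d = {:?}", old, d);
}
Whether such sentinels even make sense for a calendar type is exactly why the documentation (and naming) of these methods deserved improvement.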

What's the Rust way to modify a structure within nested loops?

Given an array of bodies that interact with each other in some way, as a newbie I approached it as I would in some other language:
struct Body {
x: i16,
y: i16,
v: i16,
}
fn main() {
let mut bodies = Vec::<Body>::new();
bodies.push(Body { x: 10, y: 10, v: 0 });
bodies.push(Body { x: 20, y: 30, v: 0 });
// keep it simple and loop only twice
for i in 0..2 {
println!("Turn {}", i);
for b_outer in bodies.iter() {
println!("x:{}, y:{}, v:{}", b_outer.x, b_outer.y, b_outer.v);
let mut a = b_outer.v;
for b_inner in bodies.iter() {
// for simplicity I ignore here to continue in case b_outer == b_inner
// just do some calculation
a = a + b_outer.x * b_inner.x;
println!(
" x:{}, y:{}, v:{}, a:{}",
b_inner.x,
b_inner.y,
b_inner.v,
a
);
}
// updating b_outer.v fails
b_outer.v = a;
}
}
}
Updating b_outer.v after the inner loop has finished fails:
error[E0594]: cannot assign to immutable field `b_outer.v`
--> src/main.rs:32:13
|
32 | b_outer.v = a;
| ^^^^^^^^^^^^^ cannot mutably borrow immutable field
Making b_outer mutable:
for b_outer in bodies.iter_mut() { ...
doesn't work either:
error[E0502]: cannot borrow `bodies` as mutable because it is also borrowed as immutable
--> src/main.rs:19:32
|
16 | for b_outer in bodies.iter() {
| ------ immutable borrow occurs here
...
19 | for b_inner in bodies.iter_mut() {
| ^^^^^^ mutable borrow occurs here
...
33 | }
| - immutable borrow ends here
Now I'm stuck. What's the Rust approach to update b_outer.v after the inner loop has finished?
For what it's worth, I think the error message is telling you that your code has a logic problem. If you update the vector between iterations of the inner loop, then those changes will be used for subsequent iterations. Let's look at a smaller example where we compute the windowed-average of an array item and its neighbors:
[2, 0, 2, 0, 2] // input
[2/3, 4/3, 2/3, 4/3, 2/3] // expected output (out-of-bounds counts as 0)
[2/3, 0, 2, 0, 2] // input after round 1
[2/3, 8/9, 2, 0, 2] // input after round 2
[2/3, 8/9, 26/9, 0, 2] // input after round 3
// I got bored here
I'd suggest computing the output into a temporary vector and then swap them:
#[derive(Debug)]
struct Body {
x: i16,
y: i16,
v: i16,
}
fn main() {
let mut bodies = vec![Body { x: 10, y: 10, v: 0 }, Body { x: 20, y: 30, v: 0 }];
for _ in 0..2 {
let next_bodies = bodies
.iter()
.map(|b| {
let next_v = bodies
.iter()
.fold(b.v, { |a, b_inner| a + b.x * b_inner.x });
Body { v: next_v, ..*b }
})
.collect();
bodies = next_bodies;
}
println!("{:?}", bodies);
}
Output:
[Body { x: 10, y: 10, v: 600 }, Body { x: 20, y: 30, v: 1200 }]
If you're really concerned about memory performance, you could create a total of two vectors, size them appropriately, and then alternate between the two. The code would be uglier, though.
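A rough sketch of that double-buffer idea (illustrative only): keep a second, pre-sized Vec<Body> and write each round's results into it, then mem::swap the two buffers so no new allocation happens per round. The logic mirrors the map/collect version above and produces the same result.
use std::mem;

#[derive(Debug, Clone)]
struct Body {
    x: i16,
    y: i16,
    v: i16,
}

fn main() {
    let mut bodies = vec![Body { x: 10, y: 10, v: 0 }, Body { x: 20, y: 30, v: 0 }];
    // Scratch buffer allocated once up front and reused every round.
    let mut scratch = bodies.clone();

    for _ in 0..2 {
        for (out, b) in scratch.iter_mut().zip(bodies.iter()) {
            let next_v = bodies.iter().fold(b.v, |a, b_inner| a + b.x * b_inner.x);
            *out = Body { v: next_v, ..*b };
        }
        // Alternate the roles of the two buffers instead of allocating a new Vec each round.
        mem::swap(&mut bodies, &mut scratch);
    }
    println!("{:?}", bodies);
}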
As Matthieu M. said, you could use Cell or RefCell, which both grant you interior mutability:
use std::cell::Cell;
#[derive(Debug, Copy, Clone)]
struct Body {
x: i16,
y: i16,
v: i16,
}
fn main() {
let bodies = vec![
Cell::new(Body { x: 10, y: 10, v: 0 }),
Cell::new(Body { x: 20, y: 30, v: 0 }),
];
for _ in 0..2 {
for b_outer_cell in &bodies {
let mut b_outer = b_outer_cell.get();
let mut a = b_outer.v;
for b_inner in &bodies {
let b_inner = b_inner.get();
a = a + b_outer.x * b_inner.x;
}
b_outer.v = a;
b_outer_cell.set(b_outer);
}
}
println!("{:?}", bodies);
}
[Cell { value: Body { x: 10, y: 10, v: 600 } }, Cell { value: Body { x: 20, y: 30, v: 1200 } }]
I know the question is like 2 years old, but I got curious about it.
This C# program produces the original desired output:
var bodies = new[] { new Body { X = 10, Y = 10, V = 0 },
new Body { X = 20, Y = 30, V = 0 } };
for (int i = 0; i < 2; i++)
{
Console.WriteLine("Turn {0}", i);
foreach (var bOuter in bodies)
{
Console.WriteLine("x:{0}, y:{1}, v:{2}", bOuter.X, bOuter.Y, bOuter.V);
var a = bOuter.V;
foreach (var bInner in bodies)
{
a = a + bOuter.X * bInner.X;
Console.WriteLine(" x:{0}, y:{1}, v:{2}, a:{3}", bInner.X, bInner.Y, bInner.V, a);
}
bOuter.V = a;
}
}
Since only v is ever changed, we could change the struct to something like this:
struct Body {
x: i16,
y: i16,
v: Cell<i16>,
}
Now I'm able to mutate v, and the program becomes:
// keep it simple and loop only twice
for i in 0..2 {
println!("Turn {}", i);
for b_outer in bodies.iter() {
let mut a = b_outer.v.get();
println!("x:{}, y:{}, v:{}", b_outer.x, b_outer.y, a);
for b_inner in bodies.iter() {
a = a + (b_outer.x * b_inner.x);
println!(
" x:{}, y:{}, v:{}, a:{}",
b_inner.x,
b_inner.y,
b_inner.v.get(),
a
);
}
b_outer.v.set(a);
}
}
It produces the same output as the C# program above. The "downside" is that whenever you want to work with v, you need to use get() or into_inner(). There may be other downsides I'm not aware of.
I decided to split the structure into one that is used as the basis for the calculation (input) in the inner loop (b_inner) and one that gathers the results (output). After the inner loop has finished, the input structure is updated in the outer loop (b_outer) and the calculation starts with the next body.
What's not so nice now is that I have to deal with two structures, and you don't see their relation from the declaration.
#[derive(Debug)]
struct Body {
x: i16,
y: i16,
}
struct Velocity {
vx: i16,
}
fn main() {
let mut bodies = Vec::<Body>::new();
let mut velocities = Vec::<Velocity>::new();
bodies.push(Body { x: 10, y: 10 });
bodies.push(Body { x: 20, y: 30 });
velocities.push(Velocity { vx: 0 });
velocities.push(Velocity { vx: 0 });
// keep it simple and loop only twice
for i in 0..2 {
println!("Turn {}", i);
for (i, b_outer) in bodies.iter().enumerate() {
println!("x:{}, y:{}, v:{}", b_outer.x, b_outer.y, velocities[i].vx);
let v = velocities.get_mut(i).unwrap();
let mut a = v.vx;
for b_inner in bodies.iter() {
// for simplicity I ignore here to continue in case b_outer == b_inner
// just do some calculation
a = a + b_outer.x * b_inner.x;
println!(" x:{}, y:{}, v:{}, a:{}", b_inner.x, b_inner.y, v.vx, a);
}
v.vx = a;
}
}
println!("{:?}", bodies);
}
Output:
[Body { x: 10, y: 10 }, Body { x: 20, y: 30 }]
