Understanding the `Ord` and `PartialOrd` derive macros in Rust
So, I have this working code:

```rust
use std::cmp::Ordering;
use std::fmt;
mod errors;
use errors::VersionParseError;
#[derive(Debug, Eq, PartialEq)]
pub struct Version {
epoch: u64,
version: String,
release: u64,
}
impl Version {
pub fn new(epoch: u64, version: impl Into<String>, release: u64) -> Self {
Self {
epoch,
version: version.into(),
release,
}
}
pub fn from_string(version: impl Into<String>) -> Result<Self, VersionParseError> {
let (epoch, version) = parse_epoch(&version.into())?;
let (version, release) = parse_release(&version)?;
Ok(Self::new(epoch, version, release))
}
}
impl PartialOrd for Version {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.epoch.cmp(&other.epoch) {
Ordering::Equal => match compare(&self.version, &other.version) {
Ordering::Equal => Some(self.release.cmp(&other.release)),
result => Some(result),
},
result => Some(result),
}
}
}
impl Ord for Version {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).unwrap()
}
}
impl fmt::Display for Version {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.epoch == 0 {
write!(f, "{}-{}", self.version, self.release)
} else {
write!(f, "{}:{}-{}", self.epoch, self.version, self.release)
}
}
}
fn parse_epoch(version: &String) -> Result<(u64, String), VersionParseError> {
match version.split_once(':') {
Some((epoch, remainder)) => match epoch.parse::<u64>() {
Ok(epoch) => Ok((epoch, String::from(remainder))),
Err(_) => Err(VersionParseError::InvalidEpoch),
},
None => Ok((0, String::from(version))),
}
}
fn parse_release(version: &String) -> Result<(String, u64), VersionParseError> {
match version.split_once('-') {
Some((remainder, release)) => match release.parse::<u64>() {
Ok(release) => Ok((String::from(remainder), release)),
Err(_) => Err(VersionParseError::InvalidRelease),
},
None => Err(VersionParseError::NoReleaseSpecified),
}
}
fn compare(lhs: &String, rhs: &String) -> Ordering {
let l_segments = segments(lhs);
let r_segments = segments(rhs);
for (l_segment, r_segment) in l_segments.iter().zip(r_segments.iter()) {
match compare_segments(l_segment, r_segment) {
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Less => {
return Ordering::Less;
}
_ => {
continue;
}
}
}
l_segments.len().cmp(&r_segments.len())
}
fn segments(version: &String) -> Vec<String> {
normalize(version)
.split(".")
.map(|segment| String::from(segment))
.collect()
}
fn normalize(version: &String) -> String {
version
.chars()
.map(|chr| if chr.is_alphanumeric() { chr } else { '.' })
.collect()
}
fn compare_segments(lhs: &str, rhs: &str) -> Ordering {
let l_blocks = blocks(lhs);
let r_blocks = blocks(rhs);
let mut last: usize = 0;
for (index, (l_block, r_block)) in l_blocks.iter().zip(r_blocks.iter()).enumerate() {
last = index;
match compare_blocks(l_block, r_block) {
Ordering::Equal => {
continue;
}
ordering => {
return ordering;
}
}
}
match l_blocks
.iter()
.nth(last + 1)
.unwrap_or(&String::new())
.chars()
.nth(0)
{
Some(chr) => {
if chr.is_ascii_digit() {
Ordering::Greater
} else {
Ordering::Less
}
}
None => match r_blocks
.iter()
.nth(last + 1)
.unwrap_or(&String::new())
.chars()
.nth(0)
{
Some(chr) => {
if chr.is_ascii_digit() {
Ordering::Less
} else {
Ordering::Greater
}
}
None => Ordering::Equal,
},
}
}
fn blocks(segment: &str) -> Vec<String> {
let mut result = Vec::new();
let mut block = String::new();
for chr in segment.chars() {
match block.chars().nth(0) {
Some(current) => {
if same_type(&chr, &current) {
block.push(chr);
} else {
result.push(block.clone());
block.clear();
block.push(chr);
}
}
None => {
block.push(chr);
}
}
}
if !block.is_empty() {
result.push(block.clone());
}
result
}
fn same_type(lhs: &char, rhs: &char) -> bool {
lhs.is_ascii_digit() == rhs.is_ascii_digit()
}
fn compare_blocks(lhs: &String, rhs: &String) -> Ordering {
if lhs == rhs {
return Ordering::Equal;
}
let l_is_number = lhs.chars().all(|chr| chr.is_ascii_digit());
let r_is_number = rhs.chars().all(|chr| chr.is_ascii_digit());
if l_is_number && r_is_number {
compare_alpha(lhs, rhs)
} else if l_is_number && !r_is_number {
Ordering::Greater
} else if !l_is_number && r_is_number {
Ordering::Less
} else {
lhs.cmp(rhs)
}
}
fn compare_alpha(lhs: &str, rhs: &str) -> Ordering {
let lhs = lhs.trim_start_matches('0');
let rhs = rhs.trim_start_matches('0');
match lhs.len().cmp(&rhs.len()) {
Ordering::Equal => lhs.cmp(&rhs),
ordering => ordering,
}
}
```
I also tried to get rid of the presumably redundant manual implementation of `Ord` and derive it with the appropriate macro instead.
However, when I remove the `Ord` implementation and add `Ord` to the derive list, the behaviour of the comparison operators (`<`, `==`, `>`) changes.
Why is that? What does the default `cmp()` implementation derived by the `Ord` macro look like?
I could not find it in the standard library, since it is a compiler built-in.
Since my `partial_cmp()` always returns `Some` (the ordering is total), I also tried the opposite: implementing `Ord` with a `cmp()` analogous to the current `partial_cmp()` (minus the `Option`) and deriving `PartialOrd` with the appropriate macro. However, this also changes, i.e. breaks, the behaviour of the comparison operators.
So my question is: how do `Ord` and `PartialOrd` play together, and do I really have to implement both traits manually, as I did above?
Addendum:
Here's the complete project:
`Cargo.toml`:
```toml
[package]
name = "librucman"
version = "0.1.1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
[dev-dependencies]
once_cell = "*"
```

`./tests/vercmp.rs`:

```rust
use librucman::version::Version;
use std::cmp::Ordering;
mod common;
use common::{load_version_pair, EQUAL, GREATER_THAN, LESS_THAN, VERSIONS};
#[test]
fn version_parsing() {
for (string, version) in VERSIONS.iter() {
assert_eq!(*version, Version::from_string(*string).unwrap());
}
}
#[test]
fn version_comparison() {
for (lhs, rhs) in EQUAL.map(load_version_pair) {
assert_eq!(lhs, rhs);
assert_eq!(Ordering::Equal, lhs.cmp(&rhs))
}
for (lhs, rhs) in GREATER_THAN.map(load_version_pair) {
assert!(lhs > rhs);
assert_eq!(Ordering::Greater, lhs.cmp(&rhs))
}
for (lhs, rhs) in LESS_THAN.map(load_version_pair) {
assert!(lhs < rhs);
assert_eq!(Ordering::Less, lhs.cmp(&rhs))
}
}
```

`./tests/common/mod.rs`:

```rust
use once_cell::sync::Lazy;
use std::collections::HashMap;
use librucman::version::Version;
pub static VERSIONS: Lazy<HashMap<&'static str, Version>> = Lazy::new(|| {
HashMap::from([
(
"1:2.3.5+r3+gd9d61d87f-3",
Version::new(1, "2.3.5+r3+gd9d61d87f", 3),
),
("2.28.3-1", Version::new(0, "2.28.3", 1)),
("2.3.2.post1-1", Version::new(0, "2.3.2.post1", 1)),
("20220913.f09bebf-1", Version::new(0, "20220913.f09bebf", 1)),
(
"2:2.06.r322.gd9b4638c5-4",
Version::new(2, "2.06.r322.gd9b4638c5", 4),
),
("4.3-3", Version::new(0, "4.3", 3)),
(
"6.04.pre2.r11.gbf6db5b4-3",
Version::new(0, "6.04.pre2.r11.gbf6db5b4", 3),
),
("7.4.3-1", Version::new(0, "7.4.3", 1)),
("r2322+3aebf69d-1", Version::new(0, "r2322+3aebf69d", 1)),
("0.4.4-1", Version::new(0, "0.4.4", 1)),
("2.14.2-363", Version::new(0, "2.14.2", 363)),
])
});
pub const EQUAL: [(&str, &str); 10] = [
("0.4.4-1", "0.4.4-1"),
("2.3.2.post1-1", "2.3.2.post1-1"),
("1:2.3.5+r3+gd9d61d87f-3", "1:2.3.5+r3+gd9d61d87f-3"),
("4.3-3", "4.3-3"),
("2.28.3-1", "2.28.3-1"),
("r2322+3aebf69d-1", "r2322+3aebf69d-1"),
("2:2.06.r322.gd9b4638c5-4", "2:2.06.r322.gd9b4638c5-4"),
("6.04.pre2.r11.gbf6db5b4-3", "6.04.pre2.r11.gbf6db5b4-3"),
("7.4.3-1", "7.4.3-1"),
("20220913.f09bebf-1", "20220913.f09bebf-1"),
];
pub const LESS_THAN: [(&str, &str); 48] = [
("0.4.4-1", "2.3.2.post1-1"),
("0.4.4-1", "1:2.3.5+r3+gd9d61d87f-3"),
("0.4.4-1", "4.3-3"),
("0.4.4-1", "2.28.3-1"),
("0.4.4-1", "2:2.06.r322.gd9b4638c5-4"),
("0.4.4-1", "6.04.pre2.r11.gbf6db5b4-3"),
("0.4.4-1", "7.4.3-1"),
("0.4.4-1", "20220913.f09bebf-1"),
("2.3.2.post1-1", "1:2.3.5+r3+gd9d61d87f-3"),
("2.3.2.post1-1", "4.3-3"),
("2.3.2.post1-1", "2.28.3-1"),
("2.3.2.post1-1", "2:2.06.r322.gd9b4638c5-4"),
("2.3.2.post1-1", "6.04.pre2.r11.gbf6db5b4-3"),
("2.3.2.post1-1", "7.4.3-1"),
("2.3.2.post1-1", "20220913.f09bebf-1"),
("1:2.3.5+r3+gd9d61d87f-3", "2:2.06.r322.gd9b4638c5-4"),
("4.3-3", "1:2.3.5+r3+gd9d61d87f-3"),
("4.3-3", "2:2.06.r322.gd9b4638c5-4"),
("4.3-3", "6.04.pre2.r11.gbf6db5b4-3"),
("4.3-3", "7.4.3-1"),
("4.3-3", "20220913.f09bebf-1"),
("2.28.3-1", "1:2.3.5+r3+gd9d61d87f-3"),
("2.28.3-1", "4.3-3"),
("2.28.3-1", "2:2.06.r322.gd9b4638c5-4"),
("2.28.3-1", "6.04.pre2.r11.gbf6db5b4-3"),
("2.28.3-1", "7.4.3-1"),
("2.28.3-1", "20220913.f09bebf-1"),
("r2322+3aebf69d-1", "0.4.4-1"),
("r2322+3aebf69d-1", "2.3.2.post1-1"),
("r2322+3aebf69d-1", "1:2.3.5+r3+gd9d61d87f-3"),
("r2322+3aebf69d-1", "4.3-3"),
("r2322+3aebf69d-1", "2.28.3-1"),
("r2322+3aebf69d-1", "2:2.06.r322.gd9b4638c5-4"),
("r2322+3aebf69d-1", "6.04.pre2.r11.gbf6db5b4-3"),
("r2322+3aebf69d-1", "7.4.3-1"),
("r2322+3aebf69d-1", "20220913.f09bebf-1"),
("6.04.pre2.r11.gbf6db5b4-3", "1:2.3.5+r3+gd9d61d87f-3"),
("6.04.pre2.r11.gbf6db5b4-3", "2:2.06.r322.gd9b4638c5-4"),
("6.04.pre2.r11.gbf6db5b4-3", "7.4.3-1"),
("6.04.pre2.r11.gbf6db5b4-3", "20220913.f09bebf-1"),
("7.4.3-1", "1:2.3.5+r3+gd9d61d87f-3"),
("7.4.3-1", "2:2.06.r322.gd9b4638c5-4"),
("7.4.3-1", "20220913.f09bebf-1"),
("20220913.f09bebf-1", "1:2.3.5+r3+gd9d61d87f-3"),
("20220913.f09bebf-1", "2:2.06.r322.gd9b4638c5-4"),
("41.1-2", "42beta+r14+g2d9d76c-2"),
("1.14.50-1", "2022d-1"),
("5.15.6+kde+r50-1", "5.15.6+kde+r177-1"),
];
pub const GREATER_THAN: [(&str, &str); 52] = [
("0.4.4-1", "r2322+3aebf69d-1"),
("2.3.2.post1-1", "0.4.4-1"),
("2.3.2.post1-1", "r2322+3aebf69d-1"),
("1:2.3.5+r3+gd9d61d87f-3", "0.4.4-1"),
("1:2.3.5+r3+gd9d61d87f-3", "2.3.2.post1-1"),
("1:2.3.5+r3+gd9d61d87f-3", "4.3-3"),
("1:2.3.5+r3+gd9d61d87f-3", "2.28.3-1"),
("1:2.3.5+r3+gd9d61d87f-3", "r2322+3aebf69d-1"),
("1:2.3.5+r3+gd9d61d87f-3", "6.04.pre2.r11.gbf6db5b4-3"),
("1:2.3.5+r3+gd9d61d87f-3", "7.4.3-1"),
("1:2.3.5+r3+gd9d61d87f-3", "20220913.f09bebf-1"),
("4.3-3", "0.4.4-1"),
("4.3-3", "2.3.2.post1-1"),
("4.3-3", "2.28.3-1"),
("4.3-3", "r2322+3aebf69d-1"),
("2.28.3-1", "0.4.4-1"),
("2.28.3-1", "2.3.2.post1-1"),
("2.28.3-1", "r2322+3aebf69d-1"),
("2:2.06.r322.gd9b4638c5-4", "0.4.4-1"),
("2:2.06.r322.gd9b4638c5-4", "2.3.2.post1-1"),
("2:2.06.r322.gd9b4638c5-4", "1:2.3.5+r3+gd9d61d87f-3"),
("2:2.06.r322.gd9b4638c5-4", "4.3-3"),
("2:2.06.r322.gd9b4638c5-4", "2.28.3-1"),
("2:2.06.r322.gd9b4638c5-4", "r2322+3aebf69d-1"),
("2:2.06.r322.gd9b4638c5-4", "6.04.pre2.r11.gbf6db5b4-3"),
("2:2.06.r322.gd9b4638c5-4", "7.4.3-1"),
("2:2.06.r322.gd9b4638c5-4", "20220913.f09bebf-1"),
("6.04.pre2.r11.gbf6db5b4-3", "0.4.4-1"),
("6.04.pre2.r11.gbf6db5b4-3", "2.3.2.post1-1"),
("6.04.pre2.r11.gbf6db5b4-3", "4.3-3"),
("6.04.pre2.r11.gbf6db5b4-3", "2.28.3-1"),
("6.04.pre2.r11.gbf6db5b4-3", "r2322+3aebf69d-1"),
("7.4.3-1", "0.4.4-1"),
("7.4.3-1", "2.3.2.post1-1"),
("7.4.3-1", "4.3-3"),
("7.4.3-1", "2.28.3-1"),
("7.4.3-1", "r2322+3aebf69d-1"),
("7.4.3-1", "6.04.pre2.r11.gbf6db5b4-3"),
("20220913.f09bebf-1", "0.4.4-1"),
("20220913.f09bebf-1", "2.3.2.post1-1"),
("20220913.f09bebf-1", "4.3-3"),
("20220913.f09bebf-1", "2.28.3-1"),
("20220913.f09bebf-1", "r2322+3aebf69d-1"),
("20220913.f09bebf-1", "6.04.pre2.r11.gbf6db5b4-3"),
("20220913.f09bebf-1", "7.4.3-1"),
("1.4rc5-14", "1.0.3.1-6"),
("2.38.0-2", "2.038ro+1.058it+1.018var-1"),
("1.21.0-1", "1.3-1"),
("0.8.2-5", "0.8-5"),
("3.2.13-1", "3.02a09-5"),
("9.0.2-1", "9.0p1-1"),
("3.2.2-2", "3.02a09-5"),
];
pub fn load_version_pair((lhs, rhs): (&str, &str)) -> (Version, Version) {
(
Version::from_string(lhs).unwrap(),
Version::from_string(rhs).unwrap(),
)
}
```

`./src/version/errors.rs`:

```rust
use std::fmt;
#[derive(Debug, Eq, PartialEq)]
pub enum VersionParseError {
InvalidEpoch,
InvalidRelease,
NoReleaseSpecified,
}
impl VersionParseError {
pub fn to_string(&self) -> &str {
match self {
Self::InvalidEpoch => "invalid epoch",
Self::InvalidRelease => "invalid release",
Self::NoReleaseSpecified => "no release specified",
}
}
}
impl fmt::Display for VersionParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.to_string())
}
}
```

`./src/version.rs`:

```rust
use std::cmp::Ordering;
use std::fmt;
mod errors;
use errors::VersionParseError;
#[derive(Debug, Eq, PartialEq)]
pub struct Version {
epoch: u64,
version: String,
release: u64,
}
impl Version {
pub fn new(epoch: u64, version: impl Into<String>, release: u64) -> Self {
Self {
epoch,
version: version.into(),
release,
}
}
pub fn from_string(version: impl Into<String>) -> Result<Self, VersionParseError> {
let (epoch, version) = parse_epoch(&version.into())?;
let (version, release) = parse_release(&version)?;
Ok(Self::new(epoch, version, release))
}
}
impl PartialOrd for Version {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.epoch.cmp(&other.epoch) {
Ordering::Equal => match compare(&self.version, &other.version) {
Ordering::Equal => Some(self.release.cmp(&other.release)),
result => Some(result),
},
result => Some(result),
}
}
}
impl Ord for Version {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).unwrap()
}
}
impl fmt::Display for Version {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.epoch == 0 {
write!(f, "{}-{}", self.version, self.release)
} else {
write!(f, "{}:{}-{}", self.epoch, self.version, self.release)
}
}
}
fn parse_epoch(version: &String) -> Result<(u64, String), VersionParseError> {
match version.split_once(':') {
Some((epoch, remainder)) => match epoch.parse::<u64>() {
Ok(epoch) => Ok((epoch, String::from(remainder))),
Err(_) => Err(VersionParseError::InvalidEpoch),
},
None => Ok((0, String::from(version))),
}
}
fn parse_release(version: &String) -> Result<(String, u64), VersionParseError> {
match version.split_once('-') {
Some((remainder, release)) => match release.parse::<u64>() {
Ok(release) => Ok((String::from(remainder), release)),
Err(_) => Err(VersionParseError::InvalidRelease),
},
None => Err(VersionParseError::NoReleaseSpecified),
}
}
fn compare(lhs: &String, rhs: &String) -> Ordering {
let l_segments = segments(lhs);
let r_segments = segments(rhs);
for (l_segment, r_segment) in l_segments.iter().zip(r_segments.iter()) {
match compare_segments(l_segment, r_segment) {
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Less => {
return Ordering::Less;
}
_ => {
continue;
}
}
}
l_segments.len().cmp(&r_segments.len())
}
fn segments(version: &String) -> Vec<String> {
normalize(version)
.split(".")
.map(|segment| String::from(segment))
.collect()
}
fn normalize(version: &String) -> String {
version
.chars()
.map(|chr| if chr.is_alphanumeric() { chr } else { '.' })
.collect()
}
fn compare_segments(lhs: &str, rhs: &str) -> Ordering {
let l_blocks = blocks(lhs);
let r_blocks = blocks(rhs);
let mut last: usize = 0;
for (index, (l_block, r_block)) in l_blocks.iter().zip(r_blocks.iter()).enumerate() {
last = index;
match compare_blocks(l_block, r_block) {
Ordering::Equal => {
continue;
}
ordering => {
return ordering;
}
}
}
match l_blocks
.iter()
.nth(last + 1)
.unwrap_or(&String::new())
.chars()
.nth(0)
{
Some(chr) => {
if chr.is_ascii_digit() {
Ordering::Greater
} else {
Ordering::Less
}
}
None => match r_blocks
.iter()
.nth(last + 1)
.unwrap_or(&String::new())
.chars()
.nth(0)
{
Some(chr) => {
if chr.is_ascii_digit() {
Ordering::Less
} else {
Ordering::Greater
}
}
None => Ordering::Equal,
},
}
}
fn blocks(segment: &str) -> Vec<String> {
let mut result = Vec::new();
let mut block = String::new();
for chr in segment.chars() {
match block.chars().nth(0) {
Some(current) => {
if same_type(&chr, &current) {
block.push(chr);
} else {
result.push(block.clone());
block.clear();
block.push(chr);
}
}
None => {
block.push(chr);
}
}
}
if !block.is_empty() {
result.push(block.clone());
}
result
}
fn same_type(lhs: &char, rhs: &char) -> bool {
lhs.is_ascii_digit() == rhs.is_ascii_digit()
}
fn compare_blocks(lhs: &String, rhs: &String) -> Ordering {
if lhs == rhs {
return Ordering::Equal;
}
let l_is_number = lhs.chars().all(|chr| chr.is_ascii_digit());
let r_is_number = rhs.chars().all(|chr| chr.is_ascii_digit());
if l_is_number && r_is_number {
compare_alpha(lhs, rhs)
} else if l_is_number && !r_is_number {
Ordering::Greater
} else if !l_is_number && r_is_number {
Ordering::Less
} else {
lhs.cmp(rhs)
}
}
fn compare_alpha(lhs: &str, rhs: &str) -> Ordering {
let lhs = lhs.trim_start_matches('0');
let rhs = rhs.trim_start_matches('0');
match lhs.len().cmp(&rhs.len()) {
Ordering::Equal => lhs.cmp(&rhs),
ordering => ordering,
}
}
```

`./src/lib.rs`:

```rust
extern crate core;
pub mod version;
```

I used https://play.rust-lang.org/ to have a look at the expanded macros (Tools -> Expand macros).
The derived implementation for `PartialOrd` is:

```rust
#[automatically_derived]
impl ::core::cmp::PartialOrd for Version {
#[inline]
fn partial_cmp(&self, other: &Version)
-> ::core::option::Option<::core::cmp::Ordering> {
match ::core::cmp::PartialOrd::partial_cmp(&self.epoch, &other.epoch)
{
::core::option::Option::Some(::core::cmp::Ordering::Equal) =>
match ::core::cmp::PartialOrd::partial_cmp(&self.version,
&other.version) {
::core::option::Option::Some(::core::cmp::Ordering::Equal)
=>
::core::cmp::PartialOrd::partial_cmp(&self.release,
&other.release),
cmp => cmp,
},
cmp => cmp,
}
}
}
```

and for `Ord`:

```rust
#[automatically_derived]
impl ::core::cmp::Ord for Version {
#[inline]
fn cmp(&self, other: &Version) -> ::core::cmp::Ordering {
match ::core::cmp::Ord::cmp(&self.epoch, &other.epoch) {
::core::cmp::Ordering::Equal =>
match ::core::cmp::Ord::cmp(&self.version, &other.version) {
::core::cmp::Ordering::Equal =>
::core::cmp::Ord::cmp(&self.release, &other.release),
cmp => cmp,
},
cmp => cmp,
}
}
}
```

This explains why the (actually quite sane) derived implementations break my test cases.
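As for the second half of the question, whether both traits really have to be implemented by hand: the ordering here is total, so the duplication can at least be removed by putting all of the logic into `cmp()` and forwarding `partial_cmp()` to it; that way the two traits can never disagree. A minimal sketch against the `Version` struct and the free `compare()` function from above (illustrating the pattern, not code from the project):

```rust
use std::cmp::Ordering;

impl Ord for Version {
    fn cmp(&self, other: &Self) -> Ordering {
        // Field-by-field comparison, but with the custom `compare()` for the
        // version string instead of `String`'s lexicographic order.
        self.epoch
            .cmp(&other.epoch)
            .then_with(|| compare(&self.version, &other.version))
            .then_with(|| self.release.cmp(&other.release))
    }
}

impl PartialOrd for Version {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Forward to `Ord` so that `<`, `>` and `cmp()` always agree.
        Some(self.cmp(other))
    }
}
```

One caveat: if `compare()` can return `Equal` for strings that are not byte-identical, the derived `PartialEq`/`Eq` will disagree with `cmp()`, so equality should then also be implemented in terms of `cmp()`.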
To avoid this, I will wrap the version string in a custom struct and implement `cmp()` for it; that should solve my issue.
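A rough sketch of what that wrapper could look like (the `VersionString` name is made up here, and it simply reuses the existing `compare()` helper), so that the derive macros become usable on `Version` again:

```rust
use std::cmp::Ordering;

#[derive(Debug, Eq, PartialEq)]
struct VersionString(String);

impl Ord for VersionString {
    fn cmp(&self, other: &Self) -> Ordering {
        // Delegate to the custom segment/block comparison.
        compare(&self.0, &other.0)
    }
}

impl PartialOrd for VersionString {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// With the custom order pushed into the field type, the field-by-field
// comparison generated by the derive macros matches the expansion shown
// above, but now goes through `VersionString::cmp()`.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct Version {
    epoch: u64,
    version: VersionString,
    release: u64,
}
```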