Procedural macro for retrieving data from a nested struct by index - rust

I am trying to write a Rust derive macro for retrieving data from a nested struct by index. The struct contains only the primitive types u8, i8, u16, i16, u32, i32, u64, i64, or other structs built from them. I have an enum, Item, that wraps the leaf field data in a common type. I want the macro to create a .get() implementation which returns an Item based on a u16 index.
Here is the desired behavior.
#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
pub enum Item {
U8(u8),
I8(i8),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
I64(i64),
}
struct NestedData {
a: u16,
b: i32,
}
#[derive(GetItem)]
struct Data {
a: i32,
b: u64,
c: NestedData,
}
let data = Data {
a: 42,
b: 1000,
c: NestedData { a: 500, b: -2 },
};
assert_eq!(data.get(0).unwrap(), Item::I32(42));
assert_eq!(data.get(1).unwrap(), Item::U64(1000));
assert_eq!(data.get(2).unwrap(), Item::U16(500));
assert_eq!(data.get(3).unwrap(), Item::I32(-2));
For this particular example, I want the macro to expand to the following...
impl Data {
pub fn get(&self, index: u16) -> Result<Item, Error> {
match index {
0 => Ok(Item::I32(self.a)),
1 => Ok(Item::U64(self.b)),
2 => Ok(Item::U16(self.c.a)),
3 => Ok(Item::I32(self.c.b)),
_ => Err(Error::BadIndex),
}
}
}
I have a working macro for a single layer struct, but I am not sure about how to modify it to support nested structs. Here is where I am at...
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DataStruct, DeriveInput, Fields, Type, TypePath};
pub fn impl_get_item(input: DeriveInput) -> syn::Result<TokenStream> {
let model_name = input.ident;
let fields = match input.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => fields.named,
_ => panic!("The GetItem derive can only be applied to structs"),
};
let mut matches = TokenStream::new();
let mut item_index: u16 = 0;
for field in fields {
let item_name = field.ident;
let item_type = field.ty;
let ts = match item_type {
Type::Path(TypePath { path, .. }) if path.is_ident("u8") => {
quote! {#item_index => Ok(Item::U8(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("i8") => {
quote! {#item_index => Ok(Item::I8(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("u16") => {
quote! {#item_index => Ok(Item::U16(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("i16") => {
quote! {#item_index => Ok(Item::I16(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("u32") => {
quote! {#item_index => Ok(Item::U32(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("i32") => {
quote! {#item_index => Ok(Item::I32(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("u64") => {
quote! {#item_index => Ok(Item::U64(self.#item_name)),}
}
Type::Path(TypePath { path, .. }) if path.is_ident("i64") => {
quote! {#item_index => Ok(Item::I64(self.#item_name)),}
}
_ => panic!("{:?} uses unsupported type {:?}", item_name, item_type),
};
matches.extend(ts);
item_index += 1;
}
let output = quote! {
#[automatically_derived]
impl #model_name {
pub fn get(&self, index: u16) -> Result<Item, Error> {
match index {
#matches
_ => Err(Error::BadIndex),
}
}
}
};
Ok(output)
}

I'm not going to give a complete answer as my proc-macro skills are non-existent, but I don't think the macro part is tricky once you've got the structure right.
The way I'd approach this is to define a trait that all the types will use. I'm going to call this Indexible which is probably bad. The point of the trait is to provide the get function and a count of all fields contained within this object.
trait Indexible {
fn nfields(&self) -> usize;
fn get(&self, idx:usize) -> Result<Item>;
}
I'm using fn nfields(&self) -> usize rather than fn nfields() -> usize as taking &self means I can use this on vectors and slices and probably some other types (It also makes the following code slightly neater).
Next you need to implement this trait for your base types:
impl Indexible for u8 {
fn nfields(&self) -> usize { 1 }
fn get(&self, idx:usize) -> Result<Item> { Ok(Item::U8(*self)) }
}
...
Generating all these is probably a good use for a macro (but not the proc macro that you're talking about).
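As a rough sketch of what such a macro could look like, a declarative macro_rules! macro can stamp out the leaf impls for all eight primitive types. The macro name impl_indexible_for_primitives is made up for this example, the unit error type is an assumption based on the Err(()) used further down, and unlike the hand-written impl above this version rejects any index other than 0.

```rust
#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
pub enum Item {
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
}

pub trait Indexible {
    fn nfields(&self) -> usize;
    // Assumption: the error type is `()`; the answer does not show its alias.
    fn get(&self, idx: usize) -> Result<Item, ()>;
}

// Hypothetical helper macro: one `impl Indexible` per `type => variant` pair.
macro_rules! impl_indexible_for_primitives {
    ($($ty:ty => $variant:ident),* $(,)?) => {
        $(
            impl Indexible for $ty {
                // A primitive is a single leaf field.
                fn nfields(&self) -> usize { 1 }
                fn get(&self, idx: usize) -> Result<Item, ()> {
                    if idx == 0 { Ok(Item::$variant(*self)) } else { Err(()) }
                }
            }
        )*
    };
}

impl_indexible_for_primitives! {
    u8 => U8, i8 => I8, u16 => U16, i16 => I16,
    u32 => U32, i32 => I32, u64 => U64, i64 => I64,
}
```

With this in place, 7u8.get(0) yields Ok(Item::U8(7)), and every primitive reports nfields() == 1.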
Next, you need to generate these for your desired types. My implementations look like this:
impl Indexible for NestedData {
fn nfields(&self) -> usize {
self.a.nfields() +
self.b.nfields()
}
fn get(&self, idx:usize) -> Result<Item> {
let idx = idx;
// member a
if idx < self.a.nfields() {
return self.a.get(idx)
}
let idx = idx - self.a.nfields();
// member b
if idx < self.b.nfields() {
return self.b.get(idx)
}
Err(())
}
}
impl Indexible for Data {
fn nfields(&self) -> usize {
self.a.nfields() +
self.b.nfields() +
self.c.nfields()
}
fn get(&self, idx:usize) -> Result<Item> {
let idx = idx;
if idx < self.a.nfields() {
return self.a.get(idx)
}
let idx = idx - self.a.nfields();
if idx < self.b.nfields() {
return self.b.get(idx)
}
let idx = idx - self.b.nfields();
if idx < self.c.nfields() {
return self.c.get(idx)
}
Err(())
}
}
You can see a complete running version in the playground.
These look like they can be easily generated by a macro.
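To illustrate how mechanical that generation is, here is a sketch using a declarative macro_rules! macro instead of a derive; a derive macro would emit the same body from the struct's field list. The macro name impl_indexible is hypothetical, the leaf impls are written by hand for only the three types this example needs, and the unit error type is assumed from the Err(()) above.

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Item {
    U16(u16),
    I32(i32),
    U64(u64),
}

pub trait Indexible {
    fn nfields(&self) -> usize;
    fn get(&self, idx: usize) -> Result<Item, ()>;
}

// Leaf impls for the primitive types used below.
impl Indexible for u16 {
    fn nfields(&self) -> usize { 1 }
    fn get(&self, _idx: usize) -> Result<Item, ()> { Ok(Item::U16(*self)) }
}
impl Indexible for i32 {
    fn nfields(&self) -> usize { 1 }
    fn get(&self, _idx: usize) -> Result<Item, ()> { Ok(Item::I32(*self)) }
}
impl Indexible for u64 {
    fn nfields(&self) -> usize { 1 }
    fn get(&self, _idx: usize) -> Result<Item, ()> { Ok(Item::U64(*self)) }
}

// Hypothetical macro: takes the struct name and its fields in declaration order.
macro_rules! impl_indexible {
    ($name:ident { $($field:ident),* $(,)? }) => {
        impl Indexible for $name {
            fn nfields(&self) -> usize {
                // Sum of the leaf counts of every member.
                0 $(+ self.$field.nfields())*
            }
            fn get(&self, idx: usize) -> Result<Item, ()> {
                let mut idx = idx;
                $(
                    // Delegate if the index falls inside this member,
                    // otherwise skip past its leaves.
                    let n = self.$field.nfields();
                    if idx < n {
                        return self.$field.get(idx);
                    }
                    idx -= n;
                )*
                let _ = idx;
                Err(())
            }
        }
    };
}

pub struct NestedData { pub a: u16, pub b: i32 }
pub struct Data { pub a: i32, pub b: u64, pub c: NestedData }

impl_indexible!(NestedData { a, b });
impl_indexible!(Data { a, b, c });
```

The expansion has the same shape as the hand-written impls: each member either answers the request or has its leaf count subtracted from the index before moving on.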
If you want slightly better error messages on types that won't work, you should explicitly treat each member as an Indexible by calling through the trait, like this: Indexible::get(&self.a, ..).
It might seem that this is not going to be particularly efficient, but the compiler is able to determine that most of these pieces are constant and inline them. For example, using rust 1.51 with -C opt-level=3, the following function
pub fn sum(data: &Data) -> usize {
let mut sum = 0;
for i in 0..data.nfields() {
sum += match data.get(i) {
Err(_) => panic!(),
Ok(Item::U8(v)) => v as usize,
Ok(Item::U16(v)) => v as usize,
Ok(Item::I32(v)) => v as usize,
Ok(Item::U64(v)) => v as usize,
_ => panic!(),
}
}
sum
}
compiles to just this
example::sum:
movsxd rax, dword ptr [rdi + 8]
movsxd rcx, dword ptr [rdi + 12]
movzx edx, word ptr [rdi + 16]
add rax, qword ptr [rdi]
add rax, rdx
add rax, rcx
ret
You can see this in the compiler explorer

Related

Want to pass a parameterized enum to a function using _ as parameter (like my_fun(MyEnum::Type(_)))

I have this next_expected_kind method that returns the next item of an iterator over Kind if it is the expected kind, or an error if not.
It works fine for variants without data, like Kind1, but I don't know how to use it when the variant carries data, like Kind2.
Something like:
let _val = match s.next_expected_kind(Kind::Kind2(str)) {
Ok(k) => str,
_ => panic!("error"),
};
Is there any trick to make this work?
https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=d21d5cff42fcca633e95b4915ce2bf1d
#[derive(PartialEq, Eq)]
enum Kind {
Kind1,
Kind2(String),
}
struct S {
kinds: std::vec::IntoIter<Kind>,
}
impl S {
fn next_expected_kind(&mut self, expected: Kind) -> Result<Kind, &str> {
match self.kinds.next() {
Some(k) if k == expected => Ok(k),
_ => Err("not expected"),
}
}
}
fn main() {
let mut s = S {
kinds: vec![Kind::Kind1, Kind::Kind2(String::from("2"))].into_iter(),
};
_ = s.next_expected_kind(Kind::Kind1);
// let _val = s.next_expected_kind(Kind::Kind2(str));
let _val = match s.kinds.next() {
Some(Kind::Kind2(str)) => str,
_ => panic!("not expected"),
};
}
You could use std::mem::discriminant() like this:
use std::mem::{Discriminant, discriminant};
#[derive(Debug, PartialEq, Eq)]
enum Kind {
Kind1,
Kind2(String),
}
struct S {
kinds: std::vec::IntoIter<Kind>,
}
impl S {
fn next_expected_kind(&mut self, expected: Discriminant<Kind>) -> Result<Kind, &str> {
match self.kinds.next() {
Some(k) if discriminant(&k) == expected => Ok(k),
_ => Err("not expected"),
}
}
}
fn main() {
let mut s = S {
kinds: vec![Kind::Kind1, Kind::Kind2(String::from("2"))].into_iter(),
};
_ = dbg!(s.next_expected_kind(discriminant(&Kind::Kind1)));
let _val = dbg!(s.next_expected_kind(discriminant(&Kind::Kind2(String::new()))));
}
The obvious drawback being that you'll have to create an instance with "empty" or default data wherever you want to call it.
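A small sketch of why a dummy payload is enough: std::mem::discriminant looks only at which variant a value is, never at the data inside it. The helper name same_variant is made up for this illustration.

```rust
use std::mem::discriminant;

#[derive(Debug, PartialEq, Eq)]
pub enum Kind {
    Kind1,
    Kind2(String),
}

/// Returns true when `a` and `b` are the same variant, whatever their payload.
pub fn same_variant(a: &Kind, b: &Kind) -> bool {
    discriminant(a) == discriminant(b)
}
```

So Kind::Kind2(String::from("2")) and Kind::Kind2(String::new()) compare as the same variant, while Kind::Kind1 and any Kind2 do not.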
The only other way I can think of would be to write a macro since you can't pass just the "variant" of an enum around.
#[derive(Debug, PartialEq, Eq)]
enum Kind {
Kind1(String),
Kind2(i32, i32),
}
struct S {
kinds: std::vec::IntoIter<Kind>,
}
macro_rules! next_expected_kind {
($self:expr, $expected:path) => {
match $self.kinds.next() {
Some(k) if matches!(k, $expected(..)) => Ok(k),
_ => Err("not expected"),
}
}
}
fn main() {
let mut s = S {
kinds: vec![Kind::Kind1(String::from("1")), Kind::Kind2(2,3)].into_iter(),
};
_ = dbg!(next_expected_kind!(&mut s, Kind::Kind1));
let _val = dbg!(next_expected_kind!(&mut s, Kind::Kind2));
}
Note: this has the limitation that every variant passed to the macro has to be a tuple variant, because the (..) pattern matches neither unit variants nor struct variants, and it's a bit clunky to use.
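One possible way to loosen the limitation noted above, offered as a sketch rather than as part of the original answer: a braced rest pattern such as Kind::Kind1 { .. } is accepted for unit, tuple, and struct variants alike, so the macro can emit that pattern instead. The Kind3 variant is hypothetical and only there to show a struct variant.

```rust
#[derive(Debug, PartialEq, Eq)]
pub enum Kind {
    Kind1,             // unit variant
    Kind2(String),     // tuple variant
    Kind3 { x: i32 },  // struct variant (hypothetical, for illustration)
}

pub struct S {
    pub kinds: std::vec::IntoIter<Kind>,
}

macro_rules! next_expected_kind {
    ($s:expr, $expected:path) => {
        match $s.kinds.next() {
            // `Path { .. }` matches the variant regardless of its shape.
            Some(k) if matches!(k, $expected { .. }) => Ok(k),
            _ => Err("not expected"),
        }
    };
}
```

As in the version above, a mismatching item is still consumed from the iterator before the error is returned.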

How do I extract information about the type in a derive macro?

I am implementing a derive macro to reduce the amount of boilerplate I have to write for similar types.
I want the macro to operate on structs which have the following format:
#[derive(MyTrait)]
struct SomeStruct {
records: HashMap<Id, Record>
}
Calling the macro should generate an implementation like so:
impl MyTrait for SomeStruct {
fn foo(&self, id: Id) -> Record { ... }
}
So I understand how to generate the code using quote:
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: TokenStream) -> TokenStream {
...
let generated = quote!{
impl MyTrait for #struct_name {
fn foo(&self, id: #id_type) -> #record_type { ... }
}
}
...
}
But what is the best way to get #struct_name, #id_type and #record_type from the input token stream?
One way is to use the venial crate to parse the TokenStream.
use proc_macro2;
use quote::quote;
use venial;
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Ensure it's deriving for a struct.
let s = match venial::parse_declaration(proc_macro2::TokenStream::from(item)) {
Ok(venial::Declaration::Struct(s)) => s,
Ok(_) => panic!("Can only derive this trait on a struct"),
Err(_) => panic!("Error parsing into valid Rust"),
};
let struct_name = s.name;
// Get the struct's first field.
let fields = s.fields;
let named_fields = match fields {
venial::StructFields::Named(named_fields) => named_fields,
_ => panic!("Expected a named field"),
};
let inners: Vec<(venial::NamedField, proc_macro2::Punct)> = named_fields.fields.inner;
if inners.len() != 1 {
panic!("Expected exactly one named field");
}
// Get the name and type of the first field.
let first_field_name = &inners[0].0.name;
let first_field_type = &inners[0].0.ty;
// Extract Id and Record from the type HashMap<Id, Record>
if first_field_type.tokens.len() != 6 {
panic!("Expected type T<R, S> for first named field");
}
let id = first_field_type.tokens[2].clone();
let record = first_field_type.tokens[4].clone();
// Implement MyTrait.
let generated = quote! {
impl MyTrait for #struct_name {
fn foo(&self, id: #id) -> #record { *self.#first_field_name.get(&id).unwrap() }
}
};
proc_macro::TokenStream::from(generated)
}

How to return different tokens during tokenization?

I'm writing a compiler in Rust and, given a string, part of the logic is to work out which "kind" each character is.
I want to return the "value" of each character. For an input of 1 + 2 each character has a "token" and should return something like:
NumberToken, 1
WhiteSpaceToken, ' '
PlusToken, '+'
WhiteSpaceToken, ' '
NumberToken, 2
My function should return something like
enum SyntaxKind {
NumberToken,
WhiteSpaceToken,
PlusToken
}
struct SyntaxToken {
kind: SyntaxKind,
value: // Some general type
}
fn next_token(line: String) -> SyntaxToken {
// Logic goes here
}
How would I implement such logic?
If you're writing out a tokenizer and you wonder what such logic might look like, you can couple these values in the same enum, e.g:
#[derive(Debug)]
enum Token {
Add,
Sub,
Whitespace,
Number(f64),
}
For more, see The Rust Programming Language, "Defining an Enum" on adding data to variants.
… and then you can use a match inside of an iterator to handle it accordingly:
#[derive(Debug)]
enum Token {
Add,
Sub,
Whitespace,
Number(f64),
}
use std::str::Chars;
use std::iter::Peekable;
struct Tokens<'a> {
source: Peekable<Chars<'a>>,
}
pub type TokenIterator<'a> = Peekable<Tokens<'a>>;
impl<'a> Tokens<'a> {
pub fn new(s: &'a str) -> TokenIterator {
Self {
source: s.chars().peekable(),
}
.peekable()
}
}
impl<'a> Iterator for Tokens<'a> {
type Item = Token;
fn next(&mut self) -> Option<Self::Item> {
match self.source.next() {
Some(' ') => Some(Token::Whitespace),
Some('+') => Some(Token::Add),
Some('-') => Some(Token::Sub),
n @ Some('0'..='9') => {
let mut number = String::from(n.unwrap());
while let Some(n) = self.source.next_if(char::is_ascii_digit) {
number.push(n);
}
Some(Token::Number(number.parse::<f64>().unwrap()))
}
Some(_) => unimplemented!(),
None => None,
}
}
}
fn main() {
let tokens = Tokens::new("1 + 2");
for token in tokens {
println!("{:?}", token);
}
}
This should then give you:
Number(1.0)
Whitespace
Add
Whitespace
Number(2.0)
Playground
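If you would rather keep the kind/value split from the question, a sketch along the following lines stores the matched source text next to the kind. The field name text, the tokenize function, and the String error type are assumptions made for this example, not something from the question.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyntaxKind {
    NumberToken,
    WhiteSpaceToken,
    PlusToken,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyntaxToken {
    pub kind: SyntaxKind,
    // The slice of input this token covers, kept as an owned string.
    pub text: String,
}

pub fn tokenize(line: &str) -> Result<Vec<SyntaxToken>, String> {
    let mut chars = line.chars().peekable();
    let mut tokens = Vec::new();
    while let Some(c) = chars.next() {
        let (kind, text) = match c {
            '+' => (SyntaxKind::PlusToken, c.to_string()),
            c if c.is_whitespace() => (SyntaxKind::WhiteSpaceToken, c.to_string()),
            c if c.is_ascii_digit() => {
                // Greedily absorb the remaining digits of the number.
                let mut text = c.to_string();
                while let Some(d) = chars.next_if(|d| d.is_ascii_digit()) {
                    text.push(d);
                }
                (SyntaxKind::NumberToken, text)
            }
            other => return Err(format!("unexpected character {:?}", other)),
        };
        tokens.push(SyntaxToken { kind, text });
    }
    Ok(tokens)
}
```

Parsing the text into a number can then happen later, only for tokens whose kind is NumberToken.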

Union-Find implementation does not update parent tags

I'm trying to create some sets of Strings and then merge some of these sets so that they have the same tag (of type usize). Once I initialize the map, I start adding strings:
self.clusters.make_set("a");
self.clusters.make_set("b");
When I call self.clusters.find("a") and self.clusters.find("b"), different values are returned, which is fine because I haven't merged the sets yet. Then I call the following method to merge two sets
let _ = self.clusters.union("a", "b");
If I call self.clusters.find("a") and self.clusters.find("b") now, I get the same value. However, when I call the finalize() method and try to iterate through the map, the original tags are returned, as if I never merged the sets.
self.clusters.finalize();
for (address, tag) in &self.clusters.map {
self.clusterizer_writer.write_all(format!("{};{}\n", address,
self.clusters.parent[*tag]).as_bytes()).unwrap();
}
// to output all keys with the same tag as a list.
let a: Vec<(usize, Vec<String>)> = {
let mut x = HashMap::new();
for (k, v) in self.clusters.map.clone() {
x.entry(v).or_insert_with(Vec::new).push(k)
}
x.into_iter().collect()
};
I can't figure out why this is the case, but I'm relatively new to Rust; maybe it's an issue with pointers?
Instead of "a" and "b", I'm actually using something like utils::arr_to_hex(&input.outpoint.txid) of type String.
This is the Rust implementation of the Union-Find algorithm that I am using:
/// Tarjan's Union-Find data structure.
#[derive(RustcDecodable, RustcEncodable)]
pub struct DisjointSet<T: Clone + Hash + Eq> {
set_size: usize,
parent: Vec<usize>,
rank: Vec<usize>,
map: HashMap<T, usize>, // Each T entry is mapped onto a usize tag.
}
impl<T> DisjointSet<T>
where
T: Clone + Hash + Eq,
{
pub fn new() -> Self {
const CAPACITY: usize = 1000000;
DisjointSet {
set_size: 0,
parent: Vec::with_capacity(CAPACITY),
rank: Vec::with_capacity(CAPACITY),
map: HashMap::with_capacity(CAPACITY),
}
}
pub fn make_set(&mut self, x: T) {
if self.map.contains_key(&x) {
return;
}
let len = &mut self.set_size;
self.map.insert(x, *len);
self.parent.push(*len);
self.rank.push(0);
*len += 1;
}
/// Returns Some(num), num is the tag of subset in which x is.
/// If x is not in the data structure, it returns None.
pub fn find(&mut self, x: T) -> Option<usize> {
let pos: usize;
match self.map.get(&x) {
Some(p) => {
pos = *p;
}
None => return None,
}
let ret = DisjointSet::<T>::find_internal(&mut self.parent, pos);
Some(ret)
}
/// Implements path compression.
fn find_internal(p: &mut Vec<usize>, n: usize) -> usize {
if p[n] != n {
let parent = p[n];
p[n] = DisjointSet::<T>::find_internal(p, parent);
p[n]
} else {
n
}
}
/// Union the subsets to which x and y belong.
/// If it returns Ok(tag), tag is the tag of the unified subset.
/// If it returns Err(), at least one of x and y is not in the disjoint-set.
pub fn union(&mut self, x: T, y: T) -> Result<usize, ()> {
let x_root;
let y_root;
let x_rank;
let y_rank;
match self.find(x) {
Some(x_r) => {
x_root = x_r;
x_rank = self.rank[x_root];
}
None => {
return Err(());
}
}
match self.find(y) {
Some(y_r) => {
y_root = y_r;
y_rank = self.rank[y_root];
}
None => {
return Err(());
}
}
// Implements union-by-rank optimization.
if x_root == y_root {
return Ok(x_root);
}
if x_rank > y_rank {
self.parent[y_root] = x_root;
return Ok(x_root);
} else {
self.parent[x_root] = y_root;
if x_rank == y_rank {
self.rank[y_root] += 1;
}
return Ok(y_root);
}
}
/// Forces all laziness, updating every tag.
pub fn finalize(&mut self) {
for i in 0..self.set_size {
DisjointSet::<T>::find_internal(&mut self.parent, i);
}
}
}
I think you're just not extracting the information out of your DisjointSet struct correctly.
I got sniped by this and implemented union find. First, with a basic usize implementation:
pub struct UnionFinderImpl {
parent: Vec<usize>,
}
Then with a wrapper for more generic types:
pub struct UnionFinder<T: Hash> {
rev: Vec<Rc<T>>,
fwd: HashMap<Rc<T>, usize>,
uf: UnionFinderImpl,
}
Both structs implement a groups() method that returns a Vec<Vec<>> of groups. Clone isn't required because I used Rc.
Playground
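Since the playground code is not reproduced here, the following is only a minimal sketch of what a usize-based union-find with a groups() method might look like; it is not the code from the link, and the names UnionFind, find, union, and groups are assumptions. The point it illustrates is that elements are grouped by the root returned from find, not by the tag they were given when inserted.

```rust
use std::collections::HashMap;

pub struct UnionFind {
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, each element being its own root.
    pub fn new(n: usize) -> Self {
        Self { parent: (0..n).collect() }
    }

    /// Returns the root of `x`, compressing the path along the way.
    pub fn find(&mut self, x: usize) -> usize {
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b` (no union-by-rank, for brevity).
    pub fn union(&mut self, a: usize, b: usize) {
        let (ra, rb) = (self.find(a), self.find(b));
        if ra != rb {
            self.parent[ra] = rb;
        }
    }

    /// Collects the elements of each set, keyed by their current root.
    pub fn groups(&mut self) -> Vec<Vec<usize>> {
        let mut by_root: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..self.parent.len() {
            let root = self.find(i);
            by_root.entry(root).or_default().push(i);
        }
        let mut groups: Vec<Vec<usize>> = by_root.into_values().collect();
        groups.sort(); // deterministic order for display and testing
        groups
    }
}
```

Applied to the question's code, the equivalent change would be to key the grouping map on the root of each tag rather than on the tag stored in map.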

Iterating over the contents of an Option, or over a specific value

Let's say that we have the following C-code (assume that srclen == dstlen and the length is divisible by 64).
void stream(uint8_t *dst, uint8_t *src, size_t dstlen) {
int i;
uint8_t block[64];
while (dstlen >= 64) {
some_function_that_initializes_block(block);
for (i=0; i<64; i++) {
dst[i] = ((src != NULL)?src[i]:0) ^ block[i];
}
dst += 64;
dstlen -= 64;
if (src != NULL) { src += 64; }
}
}
That is, a function that takes a source and a destination and XORs the source with some value that the function computes. When the source is a NULL pointer, dst is just the computed value.
In rust it is quite simple to do this when src cannot be null, we can do something like:
fn stream(dst: &mut [u8], src: &[u8]) {
let mut block = [0u8, ..64];
for (dstchunk, srcchunk) in dst.chunks_mut(64).zip(src.chunks(64)) {
some_function_that_initializes_block(block);
for (d, (&s, &b)) in dstchunk.iter_mut().zip(srcchunk.iter().zip(block.iter())) {
*d = s ^ b;
}
}
}
However let us assume that we want to be able to mimic the original C-function. Then we would like to do something like:
fn stream(dst: &mut[u8], osrc: Option<&[u8]>) {
let srciter = match osrc {
None => repeat(0),
Some(src) => src.iter()
};
// the rest of the code as above
}
Alas, this won't work since repeat(0) and src.iter() have different types. However it doesn't seem possible to solve this by using a trait object since we get a compiler error saying cannot convert to a trait object because trait 'core::iter::Iterator' is not object safe. (also there is no function in the standard library that chunks an iterator).
Is there any nice way to solve this, or should I just duplicate the code in each arm of the match statement?
Instead of repeating the code in each arm, you can call a generic inner function:
fn stream(dst: &mut[u8], osrc: Option<&[u8]>) {
fn inner<T>(dst: &mut[u8], srciter: T) where T: Iterator<u8> {
let mut block = [0u8, ..64];
//...
}
match osrc {
None => inner(dst, repeat(0)),
Some(src) => inner(dst, src.iter().map(|a| *a))
}
}
Note the additional map to make both iterators compatible (Iterator<u8>).
As you mentioned, Iterator doesn't have a built-in way to do chunking. Let's incorporate Vladimir's solution and use an iterator over chunks:
fn stream(dst: &mut[u8], osrc: Option<&[u8]>) {
const CHUNK_SIZE: uint = 64;
fn inner<'a, T>(dst: &mut[u8], srciter: T) where T: Iterator<&'a [u8]> {
let mut block = [0u8, ..CHUNK_SIZE];
for (dstchunk, srcchunk) in dst.chunks_mut(CHUNK_SIZE).zip(srciter) {
some_function_that_initializes_block(block);
for (d, (&s, &b)) in dstchunk.iter_mut().zip(srcchunk.iter().zip(block.iter())) {
*d = s ^ b;
}
}
}
static ZEROES: &'static [u8] = &[0u8, ..CHUNK_SIZE];
match osrc {
None => inner(dst, repeat(ZEROES)),
Some(src) => inner(dst, src.chunks(CHUNK_SIZE))
}
}
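The snippets in this thread predate Rust 1.0 (note [0u8, ..64], Iterator<u8>, and uint). A sketch of the same inner-function approach in current syntax might look like the following; init_block is a hypothetical stand-in for some_function_that_initializes_block, filled with arbitrary values so the example can run.

```rust
use std::iter::repeat;

const CHUNK_SIZE: usize = 64;

// Hypothetical stand-in for `some_function_that_initializes_block`.
fn init_block(block: &mut [u8; CHUNK_SIZE]) {
    for (i, b) in block.iter_mut().enumerate() {
        *b = i as u8;
    }
}

pub fn stream(dst: &mut [u8], osrc: Option<&[u8]>) {
    // Generic over any iterator that yields source chunks.
    fn inner<'a>(dst: &mut [u8], src_chunks: impl Iterator<Item = &'a [u8]>) {
        let mut block = [0u8; CHUNK_SIZE];
        for (dst_chunk, src_chunk) in dst.chunks_mut(CHUNK_SIZE).zip(src_chunks) {
            init_block(&mut block);
            for (d, (s, b)) in dst_chunk.iter_mut().zip(src_chunk.iter().zip(block.iter())) {
                *d = s ^ b;
            }
        }
    }

    // Stands in for a missing source: XOR with zero leaves the block unchanged.
    static ZEROES: [u8; CHUNK_SIZE] = [0; CHUNK_SIZE];
    match osrc {
        None => inner(dst, repeat(&ZEROES[..])),
        Some(src) => inner(dst, src.chunks(CHUNK_SIZE)),
    }
}
```

Each match arm instantiates inner with a different concrete iterator type, so no trait object or wrapper enum is needed.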
Unfortunately, it is impossible to use different iterators directly or with trait objects (which have recently been changed to disallow instantiation of trait objects with inappropriate methods i.e. ones which use Self type in their signature). There is a workaround for your particular case, however. Just use enums:
fn stream(dst: &mut [u8], src: Option<&[u8]>) {
static EMPTY: &'static [u8] = &[0u8, ..64];
enum DifferentIterators<'a> {
FromSlice(std::slice::Chunks<'a, u8>),
FromRepeat(std::iter::Repeat<&'a [u8]>)
}
impl<'a> Iterator<&'a [u8]> for DifferentIterators<'a> {
#[inline]
fn next(&mut self) -> Option<&'a [u8]> {
match *self {
FromSlice(ref mut i) => i.next(),
FromRepeat(ref mut i) => i.next()
}
}
}
let srciter = match src {
None => FromRepeat(repeat(EMPTY)),
Some(src) => FromSlice(src.chunks(64))
};
let mut block = [0u8, ..64];
for (dstchunk, srcchunk) in dst.chunks_mut(64).zip(srciter) {
some_function_that_initializes_block(block);
for (d, (&s, &b)) in dstchunk.iter_mut().zip(srcchunk.iter().zip(block.iter())) {
*d = s ^ b;
}
}
}
This is a lot of code, unfortunately, but in return it is safer and less error-prone than the C version. It is also possible to optimize it so that repeat() is not needed at all:
fn stream(dst: &mut [u8], src: Option<&[u8]>) {
static EMPTY: &'static [u8] = &[0u8, ..64];
enum DifferentIterators<'a> {
FromSlice(std::slice::Chunks<'a, u8>),
AlwaysZeros
}
impl<'a> Iterator<&'a [u8]> for DifferentIterators<'a> {
#[inline]
fn next(&mut self) -> Option<&'a [u8]> {
match *self {
FromSlice(ref mut i) => i.next(),
AlwaysZeros => Some(EMPTY),
}
}
}
let srciter = match src {
None => AlwaysZeros,
Some(src) => FromSlice(src.chunks(64))
};
let mut block = [0u8, ..64];
for (dstchunk, srcchunk) in dst.chunks_mut(64).zip(srciter) {
some_function_that_initializes_block(block);
for (d, (&s, &b)) in dstchunk.iter_mut().zip(srcchunk.iter().zip(block.iter())) {
*d = s ^ b;
}
}
}
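For readers on current Rust: since 1.0, Iterator can be used as a trait object, so the two match arms from the question can also be unified behind a Box<dyn Iterator>. The sketch below shows that option at the cost of one allocation and dynamic dispatch; the names source_bytes and xor_into are made up for this example, and block stands for whatever the block-initializing function produced.

```rust
use std::iter::repeat;

/// Yields the bytes of `src`, or an endless stream of zeros when `src` is `None`.
pub fn source_bytes<'a>(src: Option<&'a [u8]>) -> Box<dyn Iterator<Item = u8> + 'a> {
    match src {
        None => Box::new(repeat(0u8)),
        Some(bytes) => Box::new(bytes.iter().copied()),
    }
}

/// XORs the source bytes with `block`, repeated as often as needed, into `dst`.
pub fn xor_into(dst: &mut [u8], src: Option<&[u8]>, block: &[u8]) {
    for (d, (s, b)) in dst.iter_mut().zip(source_bytes(src).zip(block.iter().cycle())) {
        *d = s ^ b;
    }
}
```

The enum wrapper above avoids the allocation and the dynamic dispatch, so it remains a reasonable choice in hot loops.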

Resources