How do I extract information about the type in a derive macro? - rust

I am implementing a derive macro to reduce the amount of boilerplate I have to write for similar types.
I want the macro to operate on structs which have the following format:
#[derive(MyTrait)]
struct SomeStruct {
records: HashMap<Id, Record>
}
Calling the macro should generate an implementation like so:
impl MyTrait for SomeStruct {
fn foo(&self, id: Id) -> Record { ... }
}
So I understand how to generate the code using quote:
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: TokenStream) -> TokenStream {
...
let generated = quote!{
impl MyTrait for #struct_name {
fn foo(&self, id: #id_type) -> #record_type { ... }
}
}
...
}
But what is the best way to get #struct_name, #id_type and #record_type from the input token stream?

One way is to use the venial crate to parse the TokenStream.
use proc_macro2;
use quote::quote;
use venial;
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Ensure it's deriving for a struct.
let s = match venial::parse_declaration(proc_macro2::TokenStream::from(item)) {
Ok(venial::Declaration::Struct(s)) => s,
Ok(_) => panic!("Can only derive this trait on a struct"),
Err(_) => panic!("Error parsing into valid Rust"),
};
let struct_name = s.name;
// Get the struct's first field.
let fields = s.fields;
let named_fields = match fields {
venial::StructFields::Named(named_fields) => named_fields,
_ => panic!("Expected a named field"),
};
let inners: Vec<(venial::NamedField, proc_macro2::Punct)> = named_fields.fields.inner;
if inners.len() != 1 {
panic!("Expected exactly one named field");
}
// Get the name and type of the first field.
let first_field_name = &inners[0].0.name;
let first_field_type = &inners[0].0.ty;
// Extract Id and Record from the type HashMap<Id, Record>.
// Its six tokens are: HashMap, <, Id, ',', Record, >, so Id is at index 2 and Record at index 4.
if first_field_type.tokens.len() != 6 {
panic!("Expected type T<R, S> for first named field");
}
let id = first_field_type.tokens[2].clone();
let record = first_field_type.tokens[4].clone();
// Implement MyTrait.
let generated = quote! {
impl MyTrait for #struct_name {
fn foo(&self, id: #id) -> #record { *self.#first_field_name.get(&id).unwrap() }
}
};
proc_macro::TokenStream::from(generated)
}
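To make the result concrete, here is roughly what the macro above would expand to for SomeStruct, written out by hand. This is only a sketch: Id and Record are stand-in type aliases that the question does not define, and Record has to be Copy, because the generated body dereferences the value returned by get.

```rust
use std::collections::HashMap;

// Stand-in types for illustration; the question does not define them.
type Id = u32;
type Record = u64; // must be `Copy` for the `*...unwrap()` body to compile

trait MyTrait {
    fn foo(&self, id: Id) -> Record;
}

struct SomeStruct {
    records: HashMap<Id, Record>,
}

// Hand-written equivalent of the generated impl, with
// #struct_name = SomeStruct, #id = Id, #record = Record,
// #first_field_name = records.
impl MyTrait for SomeStruct {
    fn foo(&self, id: Id) -> Record {
        // Panics if `id` is not in the map, just like the generated code.
        *self.records.get(&id).unwrap()
    }
}
```

If Record is not Copy, the body would need to clone the value or return a reference instead.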

Related

Derive macro generation

I'm making my own Serializable trait, in the context of a client / server system.
My idea was that the messages sent by the system are an enum defined by the user of this system, so it can be customized as needed.
To ease implementing the trait on the enum, I would like to use the #[derive(Serializable)] method, as implementing it is always the same thing.
Here is the trait :
pub trait NetworkSerializable {
fn id(&self) -> usize;
fn size(&self) -> usize;
fn serialize(self) -> Vec<u8>;
fn deserialize(id: usize, data: Vec<u8>) -> Self;
}
Now, I've tried to look at the book (this one too) and this example to try to wrap my head around derive macros, but I'm really struggling to understand them and how to implement them. I've read about token streams and abstract trees, and I think I understand the basics.
Let's take the example of the id() method: it should give a unique id for each variant of the enum, so that message headers can tell which message is incoming.
Let's say I have this enum as a message system:
enum NetworkMessages {
ErrorMessage,
SpawnPlayer(usize, bool, Transform), // player id, is_mine, position
MovePlayer(usize, Transform), // player id, new_position
DestroyPlayer(usize) // player_id
}
Then, the id() function should look like this :
fn id(&self) -> usize {
match &self {
&ErrorMessage => 0,
&SpawnPlayer => 1,
&MovePlayer => 2,
&DestroyPlayer => 3,
}
}
Here was my attempt at writing this using a derive macro:
#[proc_macro_derive(NetworkSerializable)]
pub fn network_serializable_derive(input: TokenStream) -> TokenStream {
// Construct a representation of Rust code as a syntax tree
// that we can manipulate
let ast = syn::parse(input).unwrap();
// Build the trait implementation
impl_network_serializable_macro(&ast)
}
fn impl_network_serializable_macro(ast: &syn::DeriveInput) -> TokenStream {
// get enum name
let ref name = ast.ident;
let ref data = ast.data;
let (id_func, size_func, serialize_func, deserialize_func) = match data {
// Only if data is an enum, we do parsing
Data::Enum(data_enum) => {
// Iterate over enum variants
let mut id_func_internal = TokenStream2::new();
let mut variant_id: usize = 0;
for variant in &data_enum.variants {
// add the branch for the variant
id_func_internal.extend(quote_spanned!{
variant.span() => &variant_id,
});
variant_id += 1;
}
(id_func_internal, (), (), ())
}
_ => {(TokenStream2::new(), (), (), ())},
};
let expanded = quote! {
impl NetworkSerializable for #name {
// variant_checker_functions gets replaced by all the functions
// that were constructed above
fn size(&self) -> usize {
match &self {
#id_func
}
}
/*
#size_func
#serialize_func
#deserialize_func
*/
}
};
expanded.into()
}
So this is generating quite a lot of errors, with "proc macro NetworkSerializable not expanded: no proc macro dylib present" being the first. So I'm guessing there is a lot of misunderstanding on my part here.
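As a reference point for what the macro needs to emit, here is the id() body written by hand in a form that compiles. Each match arm has to be a full `pattern => value` pair, and tuple variants need `(..)` to ignore their fields. This is a sketch, not the derive itself: Transform is a stand-in for the asker's type, and id() is shown as an inherent method rather than as part of NetworkSerializable to keep it short.

```rust
// Stand-in for the asker's `Transform` type, which the question does not show.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Transform {
    x: f32,
    y: f32,
}

enum NetworkMessages {
    ErrorMessage,
    SpawnPlayer(usize, bool, Transform), // player id, is_mine, position
    MovePlayer(usize, Transform),        // player id, new_position
    DestroyPlayer(usize),                // player_id
}

impl NetworkMessages {
    // One arm per variant, numbered in declaration order.
    fn id(&self) -> usize {
        match self {
            Self::ErrorMessage => 0,
            Self::SpawnPlayer(..) => 1,
            Self::MovePlayer(..) => 2,
            Self::DestroyPlayer(..) => 3,
        }
    }
}
```

In the macro, the loop over `data_enum.variants` would then have to produce one such arm per variant, interpolating the variant's ident and the counter with `#` inside the quote.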

Implementing Strategy pattern in rust without knowing which strategy are we using at compile time

I've been trying to implement a Strategy pattern in rust, but I'm having trouble understanding how to make it work.
So let's imagine we have a trait Adder and Element:
pub trait Element {
fn to_string(&self) -> String;
}
pub trait Adder {
type E: Element;
fn add (&self, a: &Self::E, b: &Self::E) -> Self::E;
}
And we have two implementations StringAdder with StringElements and UsizeAdder with UsizeElements:
// usize
pub struct UsizeElement {
pub value: usize
}
impl Element for UsizeElement {
fn to_string(&self) -> String {
self.value.to_string()
}
}
pub struct UsizeAdder {
}
impl Adder for UsizeAdder{
type E = UsizeElement;
fn add(&self, a: &UsizeElement, b: &UsizeElement) -> UsizeElement{
UsizeElement { value: a.value + b.value }
}
}
// String
pub struct StringElement {
pub value: String
}
impl Element for StringElement {
fn to_string(&self) -> String {
self.value.to_string()
}
}
pub struct StringAdder {
}
impl Adder for StringAdder {
type E = StringElement;
fn add(&self, a: &StringElement, b: &StringElement) -> StringElement {
let a: usize = a.value.parse().unwrap();
let b: usize = b.value.parse().unwrap();
StringElement {
value: (a + b).to_string()
}
}
}
And I want to write code that uses the methods of the Adder trait and its corresponding elements without knowing at compile time which strategy is going to be used.
fn main() {
let policy = "usize";
let element = "1";
let adder = get_adder(&policy);
let element_a = get_element(&policy, element);
let result = adder.add(element_a, element_a);
}
To simplify I'm going to assign a string to policy and element but normally that would be read from a file.
Is dynamic dispatch the only way to implement get_adder and get_element? And by extension, should I define the Adder and Element traits to use trait objects and/or the Any trait?
Edit: Here is what I managed to figure out so far.
An example of possible implementation is using match to help define concrete types for the compiler.
fn main() {
let policy = "string";
let element = "1";
let secret_key = "5";
let result = cesar(policy, element, secret_key);
dbg!(result.to_string());
}
fn cesar(policy: &str, element: &str, secret_key: &str) -> Box<dyn Element>{
match policy {
"usize" => {
let adder = UsizeAdder{};
let element = UsizeElement{ value: element.parse().unwrap() };
let secret_key = UsizeElement{ value: secret_key.parse().unwrap() };
Box::new(cesar_impl(&adder, &element, &secret_key))
}
"string" => {
let adder = StringAdder{};
let element = StringElement{ value: element.to_string() };
let secret_key = StringElement{ value: secret_key.to_string() };
Box::new(cesar_impl(&adder, &element, &secret_key))
}
_ => {
panic!("Policy not supported!")
}
}
}
fn cesar_impl<A>(adder: &A, element: &A::E, secret_key: &A::E) -> A::E where A: Adder, A::E : Element {
adder.add(&element, &secret_key)
}
However, the issue is that I have to wrap every function I want to implement in a match that determines the concrete type, with a case for every available policy.
This does not seem like the proper way to implement it, as it will bloat the code and make it more error-prone and less maintainable, unless I end up using macros.
Edit 2: Here you can find an example using dynamic dispatch. However I'm not convinced it's the proper way to implement the solution.
Example using dynamic dispatch
Thank you for your help :)
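One possible way to keep the match in a single place is to hide the associated type behind an object-safe facade trait and pass elements around in their textual form. The sketch below is only an illustration under assumptions: DynAdder, FromText, add_text and from_text are hypothetical names that do not appear in the question, and the two adders are repeated so the example is self-contained.

```rust
pub trait Element {
    fn to_string(&self) -> String;
}

pub trait Adder {
    type E: Element;
    fn add(&self, a: &Self::E, b: &Self::E) -> Self::E;
}

pub struct UsizeElement {
    pub value: usize,
}
impl Element for UsizeElement {
    fn to_string(&self) -> String {
        self.value.to_string()
    }
}
pub struct UsizeAdder;
impl Adder for UsizeAdder {
    type E = UsizeElement;
    fn add(&self, a: &UsizeElement, b: &UsizeElement) -> UsizeElement {
        UsizeElement { value: a.value + b.value }
    }
}

pub struct StringElement {
    pub value: String,
}
impl Element for StringElement {
    fn to_string(&self) -> String {
        self.value.clone()
    }
}
pub struct StringAdder;
impl Adder for StringAdder {
    type E = StringElement;
    fn add(&self, a: &StringElement, b: &StringElement) -> StringElement {
        // Same behavior as in the question: panics on non-numeric text.
        let a: usize = a.value.parse().unwrap();
        let b: usize = b.value.parse().unwrap();
        StringElement { value: (a + b).to_string() }
    }
}

/// Hypothetical helper: build an element from its textual form.
pub trait FromText: Sized {
    fn from_text(text: &str) -> Option<Self>;
}
impl FromText for UsizeElement {
    fn from_text(text: &str) -> Option<Self> {
        text.parse().ok().map(|value| UsizeElement { value })
    }
}
impl FromText for StringElement {
    fn from_text(text: &str) -> Option<Self> {
        Some(StringElement { value: String::from(text) })
    }
}

/// Hypothetical object-safe facade that hides the associated type `E`.
pub trait DynAdder {
    fn add_text(&self, a: &str, b: &str) -> Option<Box<dyn Element>>;
}

// Every `Adder` whose element can be built from text gets the facade for free.
impl<A> DynAdder for A
where
    A: Adder,
    A::E: FromText + 'static,
{
    fn add_text(&self, a: &str, b: &str) -> Option<Box<dyn Element>> {
        let a = <A::E as FromText>::from_text(a)?;
        let b = <A::E as FromText>::from_text(b)?;
        Some(Box::new(self.add(&a, &b)))
    }
}

/// The only place that maps a policy name to a concrete type.
pub fn get_adder(policy: &str) -> Option<Box<dyn DynAdder>> {
    match policy {
        "usize" => Some(Box::new(UsizeAdder)),
        "string" => Some(Box::new(StringAdder)),
        _ => None,
    }
}
```

With this shape, functions such as cesar can take a `&dyn DynAdder` and no longer need their own match. The cost is one virtual call and one allocation per operation, and a new policy still requires one extra arm in get_adder.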

Is there a way to use internal symbols in a Rust macro?

Suppose my library defines pub struct MyStruct { a: i32 }. Then, suppose I define a procedural macro my_macro!. Perhaps I want my_macro!(11) to expand to MyStruct { a: 11 }. This proves that I must have used my_macro! to obtain MyStruct, since the a field is not public.
This is what I want to do:
lib.rs:
struct MyStruct {
a: i32,
}
#[proc_macro]
pub fn my_macro(input: TokenStream) -> TokenStream {
let ast: syn::Expr = syn::parse(input).unwrap();
impl_my_macro(ast)
}
fn impl_my_macro(expr: syn::Expr) -> TokenStream {
let gen = match expr {
syn::Expr::Lit(expr_lit) => match expr_lit.lit {
syn::Lit::Int(lit_int) => {
let value = lit_int.base10_parse::<i32>().unwrap();
if value > 100 {
quote::quote! {
compile_error!("Integer literal is too big.");
}
} else {
quote::quote! {
MyStruct{a: #lit_int}
}
}
}
_ => quote::quote! {
compile_error!("Expected an integer literal.");
},
},
_ => quote::quote! {
compile_error!("Expected an integer literal.");
},
};
gen.into()
}
And I would use it like this:
fn test_my_macro() {
let my_struct = secret_macro::my_macro!(10);
}
But the compiler gives this error:
error[E0422]: cannot find struct, variant or union type `MyStruct` in this scope
--> tests/test.rs:14:21
|
14 | secret_macro::my_macro!(10);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope
Is it possible to achieve this in Rust?
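Some background on the error: a proc-macro crate can only export procedural macros, and the tokens it emits are resolved in the crate that calls the macro, where no MyStruct exists. A workaround that is often used is to keep the type in an ordinary library crate and have the macro expand to a path into that crate. The sketch below uses a declarative macro with `$crate` so that it fits in a single file. The module name secret, the hidden constructor __new and the accessor value are made up for illustration. Note that `#[doc(hidden)] pub` only discourages direct calls and does not prevent them, so this does not strictly prove that the macro was used.

```rust
pub mod secret {
    pub struct MyStruct {
        a: i32, // stays private
    }

    impl MyStruct {
        // Hypothetical constructor intended to be called only by the macro.
        #[doc(hidden)]
        pub const fn __new(a: i32) -> Self {
            MyStruct { a }
        }

        // Hypothetical accessor, added so the value can be read back.
        pub fn value(&self) -> i32 {
            self.a
        }
    }
}

#[macro_export]
macro_rules! my_macro {
    ($value:literal) => {{
        // Compile-time check, mirroring the `value > 100` test in the question.
        const _: () = assert!($value <= 100, "Integer literal is too big.");
        // `$crate` always names the defining crate, wherever the macro is used.
        $crate::secret::MyStruct::__new($value)
    }};
}
```

Calling `my_macro!(10)` then yields a MyStruct whose value() is 10, while `my_macro!(101)` fails to compile. A procedural macro could follow the same idea by emitting an absolute path into the library crate instead of the bare name MyStruct.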

How to return different tokens during tokenization?

I'm writing a compiler in Rust and, given a string, part of the logic is to find out which "kind" each character is.
I want to return the "value" of each character. For an input of 1 + 2 each character has a "token" and should return something like:
NumberToken, 1
WhiteSpaceToken, ' '
PlusToken, '+'
WhiteSpaceToken, ' '
NumberToken, 2
My function should return something like
enum SyntaxKind {
NumberToken,
WhiteSpaceToken,
PlusToken
}
struct SyntaxToken {
kind: SyntaxKind,
value: // Some general type
}
fn next_token(line: String) -> SyntaxToken {
// Logic goes here
}
How would I implement such logic?
If you're writing a tokenizer and wonder what such logic might look like, you can couple these values in the same enum, e.g.:
#[derive(Debug)]
enum Token {
Add,
Sub,
Whitespace,
Number(f64),
}
For more, see The Rust Programming Language, "Defining an Enum" on adding data to variants.
… and then you can use a match inside of an iterator to handle it accordingly:
#[derive(Debug)]
enum Token {
Add,
Sub,
Whitespace,
Number(f64),
}
use std::str::Chars;
use std::iter::Peekable;
struct Tokens<'a> {
source: Peekable<Chars<'a>>,
}
pub type TokenIterator<'a> = Peekable<Tokens<'a>>;
impl<'a> Tokens<'a> {
pub fn new(s: &'a str) -> TokenIterator {
Self {
source: s.chars().peekable(),
}
.peekable()
}
}
impl<'a> Iterator for Tokens<'a> {
type Item = Token;
fn next(&mut self) -> Option<Self::Item> {
match self.source.next() {
Some(' ') => Some(Token::Whitespace),
Some('+') => Some(Token::Add),
Some('-') => Some(Token::Sub),
n @ Some('0'..='9') => {
let mut number = String::from(n.unwrap());
while let Some(n) = self.source.next_if(char::is_ascii_digit) {
number.push(n);
}
Some(Token::Number(number.parse::<f64>().unwrap()))
}
Some(_) => unimplemented!(),
None => None,
}
}
}
fn main() {
let tokens = Tokens::new("1 + 2");
for token in tokens {
println!("{:?}", token);
}
}
This should then give you:
Number(1.0)
Whitespace
Add
Whitespace
Number(2.0)
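If you would rather keep the SyntaxKind / SyntaxToken shape from the question, a rough sketch could look like the following. It makes a few assumptions that are not in the question: value is a string slice borrowed from the input, next_token also returns the unconsumed rest of the input, and tokenize is a hypothetical helper that loops until nothing is left.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum SyntaxKind {
    NumberToken,
    WhiteSpaceToken,
    PlusToken,
}

#[derive(Debug, PartialEq)]
struct SyntaxToken<'a> {
    kind: SyntaxKind,
    value: &'a str, // the exact text of the token, borrowed from the input
}

// Returns the next token together with the rest of the input,
// or None at the end of the input or on an unsupported character.
fn next_token(input: &str) -> Option<(SyntaxToken<'_>, &str)> {
    let first = input.chars().next()?;
    let (kind, len) = match first {
        '0'..='9' => {
            // A number runs until the first non-digit character.
            let len = input
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(input.len());
            (SyntaxKind::NumberToken, len)
        }
        ' ' => (SyntaxKind::WhiteSpaceToken, 1),
        '+' => (SyntaxKind::PlusToken, 1),
        _ => return None,
    };
    let (value, rest) = input.split_at(len);
    Some((SyntaxToken { kind, value }, rest))
}

// Hypothetical helper: collect tokens until next_token gives up.
fn tokenize(mut input: &str) -> Vec<SyntaxToken<'_>> {
    let mut tokens = Vec::new();
    while let Some((token, rest)) = next_token(input) {
        tokens.push(token);
        input = rest;
    }
    tokens
}
```

Here an unsupported character simply ends tokenization; a real lexer would more likely report an error with its position.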

"Registering" trait implementations + factory method for trait objects

Say we want to have object implementations switched at runtime; we'd do something like this:
pub trait Methods {
fn func(&self);
}
pub struct Methods_0;
impl Methods for Methods_0 {
fn func(&self) {
println!("foo");
}
}
pub struct Methods_1;
impl Methods for Methods_1 {
fn func(&self) {
println!("bar");
}
}
pub struct Object<'a> {
methods: &'a (Methods + 'a),
}
fn main() {
let methods: [&Methods; 2] = [&Methods_0, &Methods_1];
let mut obj = Object { methods: methods[0] };
obj.methods.func();
obj.methods = methods[1];
obj.methods.func();
}
Now, what if there are hundreds of such implementations? E.g. imagine implementations of cards for collectible card game where every card does something completely different and is hard to generalize; or imagine implementations for opcodes for a huge state machine. Sure you can argue that a different design pattern can be used -- but that's not the point of this question...
I wonder if there is any way for these Impl structs to somehow "register" themselves so they can be looked up later by a factory method. I would be happy to end up with a magical macro or even a plugin to accomplish that.
Say, in D you can use templates to register the implementations -- and if you can't for some reason, you can always inspect modules at compile-time and generate new code via mixins; there are also user-defined attributes that can help in this. In Python, you would normally use a metaclass so that every time a new child class is created, a ref to it is stored in the metaclass's registry which allows you to look up implementations by name or parameter; this can also be done via decorators if implementations are simple functions.
Ideally, in the example above you would be able to create Object as
Object::new(0)
where the value 0 is only known at runtime and it would magically return you an Object { methods: &Methods_0 }, and the body of new() would not have the implementations hard-coded like so "methods: [&Methods; 2] = [&Methods_0, &Methods_1]", instead it should be somehow inferred automatically.
So, this is probably extremely buggy, but it works as a proof of concept.
It is possible to use Cargo's code generation support to make the introspection at compile-time, by parsing (not exactly parsing in this case, but you get the idea) the present implementations, and generating the boilerplate necessary to make Object::new() work.
The code is pretty convoluted and has no error handling whatsoever, but works.
Tested on rustc 1.0.0-dev (2c0535421 2015-02-05 15:22:48 +0000)
(See on github)
src/main.rs:
pub mod implementations;
mod generated_glue {
include!(concat!(env!("OUT_DIR"), "/generated_glue.rs"));
}
use generated_glue::Object;
pub trait Methods {
fn func(&self);
}
pub struct Methods_2;
impl Methods for Methods_2 {
fn func(&self) {
println!("baz");
}
}
fn main() {
Object::new(2).func();
}
src/implementations.rs:
use super::Methods;
pub struct Methods_0;
impl Methods for Methods_0 {
fn func(&self) {
println!("foo");
}
}
pub struct Methods_1;
impl Methods for Methods_1 {
fn func(&self) {
println!("bar");
}
}
build.rs:
#![feature(core, unicode, path, io, env)]
use std::env;
use std::old_io::{fs, File, BufferedReader};
use std::collections::HashMap;
fn main() {
let target_dir = Path::new(env::var_string("OUT_DIR").unwrap());
let mut target_file = File::create(&target_dir.join("generated_glue.rs")).unwrap();
let source_code_path = Path::new(file!()).join_many(&["..", "src/"]);
let source_files = fs::readdir(&source_code_path).unwrap().into_iter()
.filter(|path| {
match path.str_components().last() {
Some(Some(filename)) => filename.split('.').last() == Some("rs"),
_ => false
}
});
let mut implementations = HashMap::new();
for source_file_path in source_files {
let relative_path = source_file_path.path_relative_from(&source_code_path).unwrap();
let source_file_name = relative_path.as_str().unwrap();
implementations.insert(source_file_name.to_string(), vec![]);
let mut file_implementations = &mut implementations[*source_file_name];
let mut source_file = BufferedReader::new(File::open(&source_file_path).unwrap());
for line in source_file.lines() {
let line_str = match line {
Ok(line_str) => line_str,
Err(_) => break,
};
if line_str.starts_with("impl Methods for Methods_") {
const PREFIX_LEN: usize = 25;
let number_len = line_str[PREFIX_LEN..].chars().take_while(|chr| {
chr.is_digit(10)
}).count();
let number: i32 = line_str[PREFIX_LEN..(PREFIX_LEN + number_len)].parse().unwrap();
file_implementations.push(number);
}
}
}
writeln!(&mut target_file, "use super::Methods;").unwrap();
for (source_file_name, impls) in &implementations {
let module_name = match source_file_name.split('.').next() {
Some("main") => "super",
Some(name) => name,
None => panic!(),
};
for impl_number in impls {
writeln!(&mut target_file, "use {}::Methods_{};", module_name, impl_number).unwrap();
}
}
let all_impls = implementations.values().flat_map(|impls| impls.iter());
writeln!(&mut target_file, "
pub struct Object;
impl Object {{
pub fn new(impl_number: i32) -> Box<Methods + 'static> {{
match impl_number {{
").unwrap();
for impl_number in all_impls {
writeln!(&mut target_file,
" {} => Box::new(Methods_{}),", impl_number, impl_number).unwrap();
}
writeln!(&mut target_file, "
_ => panic!(\"Unknown impl number: {{}}\", impl_number),
}}
}}
}}").unwrap();
}
The generated code:
use super::Methods;
use super::Methods_2;
use implementations::Methods_0;
use implementations::Methods_1;
pub struct Object;
impl Object {
pub fn new(impl_number: i32) -> Box<Methods + 'static> {
match impl_number {
2 => Box::new(Methods_2),
0 => Box::new(Methods_0),
1 => Box::new(Methods_1),
_ => panic!("Unknown impl number: {}", impl_number),
}
}
}
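The build script and the generated code above target a pre-1.0 compiler. For orientation, this is approximately what the same generated factory would look like in current Rust, where trait objects are written with `dyn`. It is a hand-written sketch, not output of the build script, and it takes one liberty: func returns its text instead of printing it, so the result can be checked.

```rust
pub trait Methods {
    // Returns the text instead of printing it (a change from the original).
    fn func(&self) -> &'static str;
}

#[allow(non_camel_case_types)]
pub struct Methods_0;
impl Methods for Methods_0 {
    fn func(&self) -> &'static str {
        "foo"
    }
}

#[allow(non_camel_case_types)]
pub struct Methods_1;
impl Methods for Methods_1 {
    fn func(&self) -> &'static str {
        "bar"
    }
}

#[allow(non_camel_case_types)]
pub struct Methods_2;
impl Methods for Methods_2 {
    fn func(&self) -> &'static str {
        "baz"
    }
}

pub struct Object;

impl Object {
    // The factory a build script would generate: one arm per implementation.
    pub fn new(impl_number: i32) -> Box<dyn Methods> {
        match impl_number {
            0 => Box::new(Methods_0),
            1 => Box::new(Methods_1),
            2 => Box::new(Methods_2),
            _ => panic!("Unknown impl number: {}", impl_number),
        }
    }
}
```

For example, `Object::new(2).func()` yields "baz". The order of the arms does not matter, which is why the HashMap iteration order in the build script is harmless.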
