Related
I am attempting to write a macro that will generate a telemetry function for any struct by using
#[derive(telemetry)] . This function will send a data stream to anything provided that is io::Writable. This data stream will be "self describing" such that the receiver doesn't need to know anything else about the data other than the bytes received. This allows the receiver to be able to correctly parse a struct and print its member variables names and values, even if variables are added, removed, order changed, or variable names renamed. The telemetry function works for non-nested structs and will print the name and type of a nested struct. But I need it to recursively print the names, types, sizes, and values of the nested structs member variables. An example is shown below as is the code.
Current behavior
use derive_telemetry::Telemetry;
use std::fs::File;
use std::io::{Write};
use std::time::{Duration, Instant};
use std::marker::Send;
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
// Nested struct used to exercise recursive telemetry.
// `#[repr(C, packed)]` gives a deterministic, padding-free layout
// (12 bytes = f64 + f32), matching the `var_size: 12` in the header output.
#[repr(C, packed)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct AnotherCustomStruct {
pub my_var_2: f64,
pub my_var_1: f32,
}
// Target struct for the derive: `#[derive(Telemetry)]` generates a
// `telemetry` method that streams a self-describing header plus the
// bincode-encoded value.
#[derive(Telemetry)]
#[derive(Debug, Serialize, Deserialize)]
struct TestStruct {
pub a: u32,
pub b: u32,
// Nested struct: currently described as one opaque header entry; the
// question asks how to recurse into its fields.
pub my_custom_struct: AnotherCustomStruct,
pub my_array: [u32; 10],
pub my_vec: Vec::<u64>,
pub another_variable : String,
}
// Output files for the serialized header and data streams.
const HEADER_FILENAME: &str = "test_file_stream.header";
const DATA_FILENAME: &str = "test_file_stream.data";
// Demonstrates current behavior: write telemetry to two files, then read
// both back with bincode and print them.
fn main() -> Result<(), Box<dyn std::error::Error>>{
let my_struct = TestStruct { a: 10,
b: 11,
my_custom_struct: AnotherCustomStruct { my_var_1: 123.456, my_var_2: 789.1023 },
my_array: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
my_vec: vec![11, 12, 13],
another_variable: "Hello".to_string()
};
// Streams are boxed `Write` trait objects behind Arc<Mutex<..>> so the
// generated method can move them into a background thread.
let file_header_stream = Mutex::new(Box::new(File::create(HEADER_FILENAME)?) as Box <dyn std::io::Write + Send + Sync>);
let file_data_stream = Mutex::new(Box::new(File::create(DATA_FILENAME)?) as Box <dyn std::io::Write + Send + Sync>);
// NOTE(review): `telemetry` spawns a thread; nothing waits for it before
// the files are re-opened below — this relies on timing.
my_struct.telemetry(Arc::new(file_header_stream), Arc::new(file_data_stream));
let header: TelemetryHeader = bincode::deserialize_from(&File::open(HEADER_FILENAME)?)?;
let data: TestStruct = bincode::deserialize_from(&File::open(DATA_FILENAME)?)?;
println!("{:#?}", header);
println!("{:?}", data);
Ok(())
}
produces
TelemetryHeader {
variable_descriptions: [
VariableDescription {
var_name_length: 1,
var_name: "a",
var_type_length: 3,
var_type: "u32",
var_size: 4,
},
VariableDescription {
var_name_length: 1,
var_name: "b",
var_type_length: 3,
var_type: "u32",
var_size: 4,
},
VariableDescription {
var_name_length: 16,
var_name: "my_custom_struct",
var_type_length: 19,
var_type: "AnotherCustomStruct",
var_size: 12,
},
VariableDescription {
var_name_length: 8,
var_name: "my_array",
var_type_length: 10,
var_type: "[u32 ; 10]",
var_size: 40,
},
VariableDescription {
var_name_length: 6,
var_name: "my_vec",
var_type_length: 14,
var_type: "Vec :: < u64 >",
var_size: 24,
},
VariableDescription {
var_name_length: 16,
var_name: "another_variable",
var_type_length: 6,
var_type: "String",
var_size: 24,
},
],
}
TestStruct { a: 10, b: 11, my_custom_struct: AnotherCustomStruct { my_var_2: 789.1023, my_var_1: 123.456 }, my_array: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], my_vec: [11, 12, 13], another_variable: "Hello" }
The data format is length of variable name, variable name, length of variable type, variable type, variable num of bytes.
Required Behavior
TelemetryHeader {
variable_descriptions: [
VariableDescription {
var_name_length: 1,
var_name: "a",
var_type_length: 3,
var_type: "u32",
var_size: 4,
},
VariableDescription {
var_name_length: 1,
var_name: "b",
var_type_length: 3,
var_type: "u32",
var_size: 4,
},
VariableDescription {
var_name_length: 16,
var_name: "my_custom_struct",
var_type_length: 19,
var_type: "AnotherCustomStruct",
var_size: 12,
},
VariableDescription {
var_name_length: 8,
var_name: "my_var_2",
var_type_length: 3,
var_type: "f64",
var_size: 8,
},
VariableDescription {
var_name_length: 8,
var_name: "my_var_1",
var_type_length: 3,
var_type: "f32",
var_size: 4,
},
VariableDescription {
var_name_length: 8,
var_name: "my_array",
var_type_length: 10,
var_type: "[u32 ; 10]",
var_size: 40,
},
VariableDescription {
var_name_length: 6,
var_name: "my_vec",
var_type_length: 14,
var_type: "Vec :: < u64 >",
var_size: 24,
},
VariableDescription {
var_name_length: 16,
var_name: "another_variable",
var_type_length: 6,
var_type: "String",
var_size: 24,
},
],
}
TestStruct { a: 10, b: 11, my_custom_struct: AnotherCustomStruct { my_var_2: 789.1023, my_var_1: 123.456 }, my_array: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], my_vec: [11, 12, 13], another_variable: "Hello" }
To reiterate, the current behavior correctly prints the variable name, type, size, and value for a struct with the Telemetry trait derived. It prints the name, type, and size for a nested struct correctly, but it does not then print the names and values of the nested structs members which is required. The code is below, apologies for such a long post, I hope this is formatted well and clear, thank you in advance.
Directory Structure
src
-main.rs
telemetry
-Cargo.toml
-src
--lib.rs
Cargo.toml
main.rs
use derive_telemetry::Telemetry;
use std::fs::File;
use std::io::{Write};
use std::time::{Duration, Instant};
use std::marker::Send;
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
const HEADER_FILENAME: &str = "test_file_stream.header";
const DATA_FILENAME: &str = "test_file_stream.data";
// Nested struct whose fields should (eventually) get their own header
// entries. `packed` + `repr(C)` keeps the layout fixed at 12 bytes.
#[repr(C, packed)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct AnotherCustomStruct {
pub my_var_2: f64,
pub my_var_1: f32,
}
// Struct under test; exercises scalars, a nested struct, a fixed array,
// a Vec, and a String through the derived telemetry.
#[derive(Telemetry)]
#[derive(Debug, Serialize, Deserialize)]
struct TestStruct {
pub a: u32,
pub b: u32,
pub my_custom_struct: AnotherCustomStruct,
pub my_array: [u32; 10],
pub my_vec: Vec::<u64>,
pub another_variable : String,
}
// Exercise the derived telemetry: time the call, give the background
// thread a second to flush, then read header and data back.
fn main() -> Result<(), Box<dyn std::error::Error>>{
let my_struct = TestStruct { a: 10,
b: 11,
my_custom_struct: AnotherCustomStruct { my_var_1: 123.456, my_var_2: 789.1023 },
my_array: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
my_vec: vec![11, 12, 13],
another_variable: "Hello".to_string()
};
let file_header_stream = Mutex::new(Box::new(File::create(HEADER_FILENAME)?) as Box <dyn std::io::Write + Send + Sync>);
let file_data_stream = Mutex::new(Box::new(File::create(DATA_FILENAME)?) as Box <dyn std::io::Write + Send + Sync>);
// Alternative sinks kept for illustration: any Write + Send + Sync works.
//let stdout_header_stream = Mutex::new(Box::new(io::stdout()) as Box <dyn std::io::Write + Send + Sync>);
//let stdout_data_stream = Mutex::new(Box::new(io::stdout()) as Box <dyn std::io::Write + Send + Sync>);
//let tcp_header_stream = Mutex::new(Box::new(TCPStream::connect(127.0.0.1)?) as Box <dyn std::io::Write + Send + Sync>);
//let tcp_data_stream = Mutex::new(Box::new(TCPStream::connect(127.0.0.1)?) as Box <dyn std::io::Write + Send + Sync>);
//let test_traits = Mutex::new(Box::new(io::stdout()) as Box <dyn std::io::Write + Send + Sync>);
let start = Instant::now();
my_struct.telemetry(Arc::new(file_header_stream), Arc::new(file_data_stream));
let duration = start.elapsed();
println!("Telemetry took: {:?}", duration);
// NOTE(review): `thread` is only in scope because the macro expansion
// emits `use std::thread;` at module level — fragile. The sleep is a
// race-prone stand-in for joining the spawned telemetry thread.
thread::sleep(Duration::from_secs(1));
let header: TelemetryHeader = bincode::deserialize_from(&File::open(HEADER_FILENAME)?)?;
let data: TestStruct = bincode::deserialize_from(&File::open(DATA_FILENAME)?)?;
println!("{:#?}", header);
println!("{:?}", data);
Ok(())
}
main Cargo.toml
[package]
name = "proc_macro_test"
version = "0.1.0"
edition = "2018"
[workspace]
members = [
"telemetry",
]
[dependencies]
derive_telemetry = { path = "telemetry" }
ndarray = "0.15.3"
crossbeam = "*"
serde = { version = "*", features=["derive"]}
bincode = "*"
[profile.dev]
opt-level = 0
[profile.release]
opt-level = 3
telemetry lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, parse_quote, DeriveInput};
/// Entry point for `#[derive(Telemetry)]`.
///
/// Parses the annotated item and either emits the generated impl, or a
/// `compile_error!` carrying the syn diagnostic.
#[proc_macro_derive(Telemetry)]
pub fn derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    // Surface parse/derive failures as proper compiler diagnostics rather
    // than panicking inside the macro.
    parse_derive_input(&ast)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Builds the generated output for `#[derive(Telemetry)]`: the
/// `VariableDescription`/`TelemetryHeader` support types, a local
/// `Telemetry` trait declaration, and the trait impl for the target struct.
///
/// Fix: the generated trait-object types were written as bare
/// `Box<std::io::Write + ...>`; trait objects require the `dyn` keyword
/// (deprecation warning in edition 2018, hard error in edition 2021).
fn parse_derive_input(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let struct_ident = &input.ident;
    let struct_data = parse_data(&input.data)?;
    let struct_fields = &struct_data.fields;
    // Add `Debug` bounds for generic parameters not used solely in PhantomData.
    let generics = add_debug_bound(struct_fields, input.generics.clone());
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let _struct_ident_str = format!("{}", struct_ident);
    // Body of the generated `telemetry` method, depending on the field kind.
    let tele_body = match struct_fields {
        syn::Fields::Named(fields_named) => handle_named_fields(fields_named)?,
        syn::Fields::Unnamed(fields_unnamed) => {
            let field_indexes = (0..fields_unnamed.unnamed.len()).map(syn::Index::from);
            let field_indexes_str = (0..fields_unnamed.unnamed.len()).map(|idx| format!("{}", idx));
            quote!(#( .field(#field_indexes_str, &self.#field_indexes) )*)
        }
        syn::Fields::Unit => quote!(),
    };
    let telemetry_declaration = quote!(
        trait Telemetry {
            fn telemetry(self, header_stream: Arc<Mutex::<Box<dyn std::io::Write + std::marker::Send + Sync>>>, data_stream: Arc<Mutex::<Box<dyn std::io::Write + std::marker::Send + Sync>>>);
        }
    );
    syn::Result::Ok(
        quote!(
            use std::thread;
            use std::collections::VecDeque;
            // One header entry per described variable.
            #[derive(Serialize, Deserialize, Default, Debug)]
            pub struct VariableDescription {
                pub var_name_length: usize,
                pub var_name: String,
                pub var_type_length: usize,
                pub var_type: String,
                pub var_size: usize,
            }
            #[derive(Serialize, Deserialize, Default, Debug)]
            pub struct TelemetryHeader {
                pub variable_descriptions: VecDeque::<VariableDescription>,
            }
            #telemetry_declaration
            impl #impl_generics Telemetry for #struct_ident #ty_generics #where_clause {
                fn telemetry(self, header_stream: Arc<Mutex::<Box<dyn std::io::Write + std::marker::Send + Sync>>>, data_stream: Arc<Mutex::<Box<dyn std::io::Write + std::marker::Send + Sync>>>) {
                    // Serialization runs on a detached background thread;
                    // the caller has no handle to join it.
                    thread::spawn(move || {
                        #tele_body;
                    });
                }
            }
        )
    )
}
/// Generates the `telemetry` body for a struct with named fields: build a
/// `TelemetryHeader` describing every field (name, type, size), then write
/// the bincode-encoded header and value to their respective streams.
///
/// Fix: the generated code used `Write::write`, which may perform a short
/// write and silently drop the tail of the buffer; `write_all` guarantees
/// the whole serialized buffer is written (or returns an error).
fn handle_named_fields(fields: &syn::FieldsNamed) -> syn::Result<proc_macro2::TokenStream> {
    let idents = fields.named.iter().map(|f| &f.ident);
    let types = fields.named.iter().map(|f| &f.ty);
    let num_entities = fields.named.len();
    let body = quote! (
        let mut tele_header = TelemetryHeader {variable_descriptions: VecDeque::with_capacity(#num_entities)};
        #(
            tele_header.variable_descriptions.push_back( VariableDescription {
                var_name_length: stringify!(#idents).len(),
                var_name: stringify!(#idents).to_string(),
                var_type_length: stringify!(#types).len(),
                var_type: stringify!(#types).to_string(),
                var_size: std::mem::size_of_val(&self.#idents),
            });
        )*
        // write_all: a plain `write` may only flush part of the buffer.
        header_stream.lock().unwrap().write_all(&bincode::serialize(&tele_header).unwrap()).unwrap();
        data_stream.lock().unwrap().write_all(&bincode::serialize(&self).unwrap()).unwrap();
    );
    syn::Result::Ok(body)
}
/// Emits `println!` statements describing a single named field (name
/// length, name, type length, type, value). Fields carrying attributes are
/// currently skipped (attribute handling is stubbed out).
fn parse_named_field(field: &syn::Field) -> proc_macro2::TokenStream {
    // Attribute-carrying fields are not handled yet; bail out early.
    if !field.attrs.is_empty() {
        //parse_named_field_attrs(field)
        return quote!();
    }
    let ident = field.ident.as_ref().unwrap();
    let name = format!("{}", ident);
    let ty = &field.ty;
    quote!(
        println!("Var Name Length: {}", stringify!(#name).len());
        println!("Var Name: {}", #name);
        println!("Var Type Length: {}", stringify!(#ty).len());
        println!("Var Type: {}", stringify!(#ty));
        println!("Var Val: {}", &self.#ident);
    )
}
/// Handles a field that carries an attribute. Only `#[debug = "..."]` is
/// accepted; the literal becomes a format string for the field value.
fn parse_named_field_attrs(field: &syn::Field) -> syn::Result<proc_macro2::TokenStream> {
let ident = field.ident.as_ref().unwrap();
let ident_str = format!("{}", ident);
// Only the LAST attribute on the field is inspected.
let attr = field.attrs.last().unwrap();
if !attr.path.is_ident("debug") {
return syn::Result::Err(syn::Error::new_spanned(
&attr.path,
"value must be \"debug\"",
));
}
let attr_meta = &attr.parse_meta();
match attr_meta {
Ok(syn::Meta::NameValue(syn::MetaNameValue { lit, .. })) => {
let debug_assign_value = lit;
// Emit a `.field(...)` call using the user-supplied format string.
syn::Result::Ok(quote!(
.field(#ident_str, &std::format_args!(#debug_assign_value, &self.#ident))
))
}
Ok(meta) => syn::Result::Err(syn::Error::new_spanned(meta, "expected meta name value")),
Err(err) => syn::Result::Err(err.clone()),
}
}
fn parse_data(data: &syn::Data) -> syn::Result<&syn::DataStruct> {
match data {
syn::Data::Struct(data_struct) => syn::Result::Ok(data_struct),
syn::Data::Enum(syn::DataEnum { enum_token, .. }) => syn::Result::Err(
syn::Error::new_spanned(enum_token, "CustomDebug is not implemented for enums"),
),
syn::Data::Union(syn::DataUnion { union_token, .. }) => syn::Result::Err(
syn::Error::new_spanned(union_token, "CustomDebug is not implemented for unions"),
),
}
}
/// Adds a `std::fmt::Debug` bound to every generic parameter except those
/// used *only* inside `PhantomData`.
///
/// Fix: the reference to `param.ident` was corrupted by HTML-entity
/// mangling (`&&param` had been turned into `&¶m`), which does not
/// compile; restored to `&&param.ident`.
fn add_debug_bound(fields: &syn::Fields, mut generics: syn::Generics) -> syn::Generics {
    // Idents of generic params that appear inside PhantomData vs elsewhere.
    let mut phantom_ty_idents = std::collections::HashSet::new();
    let mut non_phantom_ty_idents = std::collections::HashSet::new();
    let g = generics.clone();
    for (ident, opt_iter) in fields
        .iter()
        .flat_map(extract_ty_path)
        .map(|path| extract_ty_idents(path, g.params.iter().flat_map(|p| {
            if let syn::GenericParam::Type(ty) = p {
                std::option::Option::Some(&ty.ident)
            } else {
                std::option::Option::None
            }
        } ).collect()))
    {
        if ident == "PhantomData" {
            // If the field type ident is `PhantomData`
            // add the generic parameters into the phantom idents collection
            if let std::option::Option::Some(args) = opt_iter {
                for arg in args {
                    phantom_ty_idents.insert(arg);
                }
            }
        } else {
            // Else, add the type and existing generic parameters into the non-phantom idents collection
            non_phantom_ty_idents.insert(ident);
            if let std::option::Option::Some(args) = opt_iter {
                for arg in args {
                    non_phantom_ty_idents.insert(arg);
                }
            }
        }
    }
    // Find the difference between the phantom idents and non-phantom idents
    // Collect them into an hash set for O(1) lookup
    let non_debug_fields = phantom_ty_idents
        .difference(&non_phantom_ty_idents)
        .collect::<std::collections::HashSet<_>>();
    // Iterate generic params; only bound those NOT exclusive to PhantomData.
    for param in generics.type_params_mut() {
        if !non_debug_fields.contains(&&param.ident) {
            param.bounds.push(parse_quote!(std::fmt::Debug));
        }
    }
    generics
}
/// Extract the path from the type path in a field; `None` for any
/// non-path type (arrays, references, etc.).
fn extract_ty_path(field: &syn::Field) -> std::option::Option<&syn::Path> {
    match &field.ty {
        syn::Type::Path(type_path) => std::option::Option::Some(&type_path.path),
        _ => std::option::Option::None,
    }
}
/// From a `syn::Path` extract both the type ident and an iterator over generic type arguments.
fn extract_ty_idents<'a>(
path: &'a syn::Path,
params: std::collections::HashSet<&'a syn::Ident>,
) -> (
&'a syn::Ident,
std::option::Option<impl Iterator<Item = &'a syn::Ident>>,
) {
// The last path segment names the type itself (e.g. `Vec` in `std::vec::Vec<T>`).
let ty_segment = path.segments.last().unwrap();
let ty_ident = &ty_segment.ident;
// Only angle-bracketed arguments (`Foo<T, U>`) yield an iterator.
if let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args, ..
}) = &ty_segment.arguments
{
let ident_iter = args.iter().flat_map(move |gen_arg| {
if let syn::GenericArgument::Type(syn::Type::Path(syn::TypePath { path, .. })) = gen_arg
{
match path.segments.len() {
// Two segments (`T::Assoc`): keep the associated ident only if
// the leading segment is a known generic parameter.
2 => {
let ty = path.segments.first().unwrap();
let assoc_ty = path.segments.last().unwrap();
if params.contains(&ty.ident) {
std::option::Option::Some(&assoc_ty.ident)
} else {
std::option::Option::None
}
}
// A single plain ident (`T`).
1 => path.get_ident(),
// Longer paths are not handled yet.
_ => std::unimplemented!("kinda tired of edge cases"),
}
} else {
std::option::Option::None
}
});
(ty_ident, std::option::Option::Some(ident_iter))
} else {
(ty_ident, std::option::Option::None)
}
}
// Placeholder unit test for the derive crate; no macro-expansion tests yet.
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
// Sanity check only.
assert_eq!(2 + 2, 4);
}
}
telemetry Cargo.toml
[package]
name = "derive_telemetry"
version = "0.0.0"
edition = "2018"
autotests = false
publish = false
[lib]
proc-macro = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
proc-macro2 = ">= 1.0.29"
syn = ">= 1.0.76"
quote = ">= 1.0.9"
crossbeam = "*"
serde = { version = "*", features=["derive"]}
bincode = "*"
Once again, apologies for the lengthy post. I hope this is clear and I believe this is everything to reproduce what I have and unfortunately I believe it is a minimum working example otherwise I would take more out for ease of reading/answering
I think if I could call handle_named_fields recursively on any Field that is a struct I would get the desired behavior. The only problem with this approach is that I don't see an obvious way to tell if a Field is a struct or not. If I had a syn::Data it would be trivial, but I can't see how to do this with a syn::Field.
What you're trying is maybe a bit complicated for a SO question. I can only give an answer sketch (which would usually be a comment, but it's too long for that).
You can't parse "nested" structs in one macro run. Your derive macro gets access to the one struct it's running on, that's it. The only thing you can do is have the generated code "recursively" call other generated code (i.e. TestStruct::telemetry could call AnotherCustomStruct::telemetry).
There is one fundamental problem, though: Your macro won't know which struct members to generate a recursive call for. To solve that, you can either do the recursive call on all members, and implement Telemetry for a buttload of existing types, or you ask the user to add some #[recursive_telemetry] attribute to struct members they want included, and fire only on that.
I am trying to use serde together with bincode to de-serialize an arbitrary bitcoin network message. Given that the payload is handled ubiquitously as a byte array, how do I de-serialize it when the length is unknown at compile-time? bincode does by default handle Vec<u8> by assuming its length is encoded as u64 right before the elements of the vector. However, this assumption does not hold here because the checksum comes after the length of the payload.
I have the following working solution
Cargo.toml
[package]
name = "serde-test"
version = "0.1.0"
edition = "2018"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
bincode = "1.3.3"
main.rs
use bincode::Options;
use serde::{Deserialize, Deserializer, de::{SeqAccess, Visitor}};
// Raw Bitcoin p2p message framing; field order matches the wire layout,
// with a variable-length payload whose size is given by `length`.
#[derive(Debug)]
struct Message {
// https://en.bitcoin.it/wiki/Protocol_documentation#Message_structure
magic: u32,
command: [u8; 12],
length: u32,
checksum: u32,
payload: Vec<u8>,
}
struct MessageVisitor;
impl<'de> Visitor<'de> for MessageVisitor {
type Value = Message;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("Message")
}
fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error> where V: SeqAccess<'de>,
{
let magic = seq.next_element()?.unwrap();
let command = seq.next_element()?.unwrap();
let length: u32 = seq.next_element()?.unwrap();
let checksum = seq.next_element()?.unwrap();
let payload = (0..length).map(|_| seq.next_element::<u8>().unwrap().unwrap()).collect();
// verify payload checksum (omitted for brevity)
Ok(Message {magic, command, length, checksum, payload})
}
}
impl<'de> Deserialize<'de> for Message {
fn deserialize<D>(deserializer: D) -> Result<Message, D::Error> where D: Deserializer<'de>,
{
// Pretend the message is a huge tuple so the visitor can pull the four
// header fields plus a variable number of payload bytes; 5000 is an
// arbitrary upper bound the author wants to eliminate.
deserializer.deserialize_tuple(5000, MessageVisitor) // <-- overallocation
}
}
fn main() {
// Captured Bitcoin "version" message bytes.
let bytes = b"\xf9\xbe\xb4\xd9version\x00\x00\x00\x00\x00e\x00\x00\x00_\x1ai\xd2r\x11\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\xbc\x8f^T\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xc6\x1bd\t \x8d\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xcb\x00q\xc0 \x8d\x12\x805\xcb\xc9yS\xf8\x0f/Satoshi:0.9.3/\xcf\x05\x05\x00\x01";
// Fixint encoding matches the fixed-width wire header fields.
let msg: Message = bincode::DefaultOptions::new().with_fixint_encoding().deserialize(bytes).unwrap();
println!("{:?}", msg);
}
Output:
Message { magic: 3652501241, command: [118, 101, 114, 115, 105, 111, 110, 0, 0, 0, 0, 0], length: 101, checksum: 3530103391, payload: [114, 17, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 188, 143, 94, 84, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 198, 27, 100, 9, 32, 141, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 203, 0, 113, 192, 32, 141, 18, 128, 53, 203, 201, 121, 83, 248, 15, 47, 83, 97, 116, 111, 115, 104, 105, 58, 48, 46, 57, 46, 51, 47, 207, 5, 5, 0, 1] }
I dislike this solution because of how payload is handled. It requires me to allocate some "large enough" buffer to take into account the dynamic size of the payload. In the code snippet above, 5000 is sufficient. I would much rather de-serialize payload as a single element and use deserializer.deserialize_tuple(5, MessageVisitor) instead.
Is there a way to handle this kind of deserialization in a succint manner?
Similar question I could find: Can I deserialize vectors with variable length prefix with Bincode?
Your problem is that the source message is not encoded as bincode, so you are doing weird things to treat non-bincode data as if it was.
Serde is designed for creating serializers and deserializers for general-purpose formats, but your message is in a very specific format that can only be interpreted one way.
A library like nom is much more suitable for this kind of work, but it may be overkill considering how simple the format is and you can just parse it from the bytes directly:
use std::convert::TryInto;
/// Parses the same captured "version" message by slicing the byte buffer
/// directly — no serde involved. Panics on malformed/short input (error
/// handling intentionally omitted, as noted in the answer).
fn main() {
    let raw: &[u8] = b"\xf9\xbe\xb4\xd9version\x00\x00\x00\x00\x00e\x00\x00\x00_\x1ai\xd2r\x11\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\xbc\x8f^T\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xc6\x1bd\t \x8d\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xcb\x00q\xc0 \x8d\x12\x805\xcb\xc9yS\xf8\x0f/Satoshi:0.9.3/\xcf\x05\x05\x00\x01";
    // Peel the fixed-width header fields off the front, one at a time.
    let (magic_raw, rest) = raw.split_at(4);
    let (command_raw, rest) = rest.split_at(12);
    let (length_raw, rest) = rest.split_at(4);
    let (checksum_raw, rest) = rest.split_at(4);
    // All integers are little-endian on the wire.
    let length = u32::from_le_bytes(length_raw.try_into().unwrap());
    let msg = Message {
        magic: u32::from_le_bytes(magic_raw.try_into().unwrap()),
        command: command_raw.try_into().unwrap(),
        length,
        checksum: u32::from_le_bytes(checksum_raw.try_into().unwrap()),
        payload: rest[..length as usize].to_vec(),
    };
    println!("{:?}", msg);
}
There are hundreds of cryptocurrency projects in Rust and there are many crates already written for handling cryptocurrency data structures. These crates are battle-tested and will have much better error-handling (my example above has none). As mentioned in the comments, you can perhaps look at the bitcoin crate.
Closed. This question does not meet Stack Overflow guidelines. It is not currently accepting answers.
We don’t allow questions seeking recommendations for books, tools, software libraries, and more. You can edit the question so it can be answered with facts and citations.
Closed 1 year ago.
Improve this question
I have this function, but I feel that it's duplicate code. I was wondering if anyone here shares how to make it more rustacean. I'm still learning Rust, and I thought this could be a good example to share.
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let start_date_parsed: i64 = Utc
.ymd(
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[0]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[1]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
FromStr::from_str(start_date.split('-').collect::<Vec<&str>>()[2]).unwrap_or_else(
|_| {
eprintln!("{}", &message);
process::exit(1);
},
),
)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
let end_date_parsed: i64 = Utc
.ymd(
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[0]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[1]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
FromStr::from_str(end_date.split('-').collect::<Vec<&str>>()[2]).unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
}),
)
.and_hms_milli(0, 0, 2, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
(start_date_parsed, end_date_parsed)
Mainly to remove the three arguments passed to Utc.ymd since they are doing the same, just using a different index, they are parsing dates such as "2021-01-01" and returning it in milliseconds and clamping it to a floor and ceiling.
Can we call "Don't reinvent the wheel" a Rustacean way? Probably..
Playground
You can use NaiveDate::parse_from_str, with a lot of formatting options.
You can then use Utc::from_utc_date or Utc::from_local_date to obtain the Date and process it later like you did
This reduces the code to:
use chrono::*;
/// Parses two YYYY-MM-DD strings and returns their UTC timestamps in
/// milliseconds, clamped to [2016-01-01, now]. Start gets second 1, end
/// gets second 2, mirroring the original. Panics with `message` on a
/// malformed date.
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
// parse_from_str replaces the three manual split/parse calls.
let start_date = NaiveDate::parse_from_str(start_date, "%Y-%m-%d")
.expect(&message);
let start_date_parsed: i64 = Utc.from_utc_date(&start_date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
let end_date = NaiveDate::parse_from_str(end_date, "%Y-%m-%d")
.expect(&message);
let end_date_parsed: i64 = Utc.from_utc_date(&end_date)
.and_hms_milli(0, 0, 2, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
(start_date_parsed, end_date_parsed)
}
That's where the question of code design comes, and it's opinion based. You see - you repeat the same code twice for processing start and end date - this code ideally should become a separate function.
use chrono::*;
/// Parses a YYYY-MM-DD string into a UTC millisecond timestamp, clamped to
/// [2016-01-01, now].
/// NOTE(review): always uses second 1; the original used 1 for the start
/// date and 2 for the end date — confirm this difference is acceptable.
fn parse_and_transform_date(date_str: &str) -> Result<i64, format::ParseError> {
let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")?;
let date_parsed: i64 = Utc.from_utc_date(&date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
Ok(date_parsed)
}
/// Parses both dates; aborts with a descriptive panic if either fails.
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let start_date_result = parse_and_transform_date(start_date);
let end_date_result = parse_and_transform_date(end_date);
if let (Ok(start_date), Ok(end_date)) = (start_date_result, end_date_result) {
return (start_date, end_date); // success
}
// At least one parse failed: report both inputs and abort.
panic!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
}
Finally I would also let you consider a few problems:
It's not a good taste to abort your whole program with an error message like that, I would suggest check_and_transform_dates to instead return a Result and then considering this result, your calling code should handle the situation properly.
check_and_transform_dates should also probably do additional checks, e.g. check that end_date is not before start_date, etc
Example of fixing:
use chrono::*;
/// Parses a YYYY-MM-DD string into a UTC millisecond timestamp, clamped to
/// [2016-01-01, now]; errors propagate instead of exiting the process.
/// NOTE(review): both dates get second 1 here, unlike the original 1/2 split.
fn parse_and_transform_date(date_str: &str) -> Result<i64, format::ParseError> {
let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")?;
let date_parsed: i64 = Utc.from_utc_date(&date)
.and_hms_milli(0, 0, 1, 0)
.timestamp_millis()
.clamp(
Utc.ymd(2016, 1, 1)
.and_hms_milli(0, 0, 0, 0)
.timestamp_millis(),
Utc::now().timestamp_millis(),
);
Ok(date_parsed)
}
/// Parses both dates; returns `None` if either fails to parse or if the
/// range is inverted (start after end).
fn check_and_transform_dates(start_date: &str, end_date: &str) -> Option<(i64, i64)> {
    // Bail out on the first parse failure.
    let start = parse_and_transform_date(start_date).ok()?;
    let end = parse_and_transform_date(end_date).ok()?;
    // Reject inverted ranges.
    if start <= end {
        Some((start, end))
    } else {
        None
    }
}
fn main() {
// Panics with the message below if parsing fails or the range is inverted.
println!("{:?}", check_and_transform_dates("2020-01-02", "2020-01-03")
.expect("Data could not be downloaded"));
}
I'm going with this refactoring
/// Final refactoring: a local closure parses and clamps each date into
/// [2016-01-01, today], exiting the process with `message` on bad input.
/// `%F` is shorthand for `%Y-%m-%d`.
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let earliest: NaiveDate = NaiveDate::from_ymd(2016, 1, 1);
let today: NaiveDate = Utc::today().naive_utc();
let parse_date = |date: &str| -> NaiveDate {
let date: NaiveDate = NaiveDate::parse_from_str(date, "%F").unwrap_or_else(|_| {
eprintln!("{}", &message);
process::exit(1);
});
// Manual clamp into [earliest, today].
if date < earliest {
earliest
} else if date > today {
today
} else {
date
}
};
// Start gets second 1, end gets second 2 (as in the original), scaled to ms.
(
parse_date(start_date).and_hms(0, 0, 1).timestamp() * 1000,
parse_date(end_date).and_hms(0, 0, 2).timestamp() * 1000,
)
}
make a function for checking and clamping
fn check_and_clamp_date() is needed for start_date and end_date
extern crate chrono;
use chrono::{Duration, NaiveDate, Utc};
static FMT: &str = "%Y-%m-%d";
/// Parses `date` (per `FMT`), clamps it between the two bound dates, and
/// returns milliseconds since the Unix epoch.
///
/// NOTE: despite the names, `ceil_date` is used as the LOWER bound and
/// `floor_date` as the UPPER bound (see the `clamp` call); the call site
/// passes ("2016-01-01", now) in that order. Names kept for compatibility.
///
/// Fix: `num_nanoseconds()` returns `None` on i64 overflow (durations
/// beyond ~292 years), so the old `unwrap() / 1000000` had a panic path;
/// `num_milliseconds()` is total and computes the same value directly.
fn check_and_clamp_date(date: &str, ceil_date: &str, floor_date: &str) -> Result<i64, String> {
    // Bounds are caller-supplied literals; a malformed bound is a bug.
    let ceil = NaiveDate::parse_from_str(ceil_date, FMT).expect("start corrupt");
    let floor = NaiveDate::parse_from_str(floor_date, FMT).expect("end corrupt");
    match NaiveDate::parse_from_str(date, FMT) {
        Err(_) => Err("invalid date format".to_string()),
        Ok(d) => {
            // Clamp into the allowed window, then measure from the epoch.
            let delta: Duration = d.clamp(ceil, floor).signed_duration_since(NaiveDate::from_ymd(1970, 1, 1));
            Ok(delta.num_milliseconds())
        },
    }
}
/// Wrapper: clamp both dates into [2016-01-01, today] and offset start/end
/// by +1s/+2s (in ms) to mirror the original and_hms(0,0,1)/(0,0,2) values.
fn check_and_transform_dates(start_date: &str, end_date: &str) -> (i64, i64) {
let message: String = format!(
"Data could not be downloaded ❌, please make sure your dates
are in the following format YYYY-MM-DD
(ie. 2020-01-01), your dates are Start Date: {}, End Date: {}",
&start_date, &end_date,
);
let now = Utc::now().format(FMT).to_string();
let start = check_and_clamp_date(start_date, "2016-01-01", &now).expect(&message) + 1000;
let end = check_and_clamp_date(end_date, "2016-01-01", &now).expect(&message) + 2000;
(start, end)
}
I'm currently using serde-hex.
use serde_hex::{SerHex,StrictPfx,CompactPfx};
// Example struct mixing two serde-hex strategies: a fixed-size array with
// strict prefixed hex, and a u64 with compact prefixed hex.
#[derive(Debug,PartialEq,Eq,Serialize,Deserialize)]
struct Foo {
#[serde(with = "SerHex::<StrictPfx>")]
bar: [u8;4],
#[serde(with = "SerHex::<CompactPfx>")]
bin: u64
}
// Demonstrates the problem: the JSON assertion passes, but the bincode one
// fails because serde-hex stringifies for binary formats too.
fn it_works() {
let foo = Foo { bar: [0,1,2,3], bin: 16 };
let ser = serde_json::to_string(&foo).unwrap();
let exp = r#"{"bar":"0x00010203","bin":"0x10"}"#;
assert_eq!(ser,exp);
// this fails
let binser = bincode::serialize(&foo).unwrap();
let binexp: [u8; 12] = [0, 1, 2, 3, 16, 0, 0, 0, 0, 0, 0, 0];
assert_eq!(binser,binexp);
}
fails with:
thread 'wire::era::tests::it_works' panicked at 'assertion failed: `(left == right)`
left: `[10, 0, 0, 0, 0, 0, 0, 0, 48, 120, 48, 48, 48, 49, 48, 50, 48, 51, 4, 0, 0, 0, 0, 0, 0, 0, 48, 120, 49, 48]`,
right: `[0, 1, 2, 3, 16, 0, 0, 0, 0, 0, 0, 0]`', src/test.rs:20:9
because it has expanded values to hex strings for bincode.
I have many structs which I need to serialise with both serde_json and bincode. serde_hex does exactly what I need for JSON serialisation. When using bincode serde-hex still transforms arrays into hex strings, which is not wanted.
I notice that secp256k1 uses d.is_human_readable().
How can I make serde_hex apply only to serde_json and be ignored for bincode?
The implementation of a function usable with serde's with-attribute is mostly boilerplate and looks like this.
This only differentiates between human-readable and other formats. If you need more fine-grained control, you could branch on a thread-local variable instead.
/// Serializes `v` as a "0x"-prefixed, zero-padded hex string for
/// human-readable formats (e.g. serde_json) and as the plain binary
/// representation for non-human-readable formats (e.g. bincode).
///
/// Use with serde's field attribute:
/// `#[serde(serialize_with = "serialize_hex")]`
///
/// Generalized from the original `&u64` parameter to any `T` that
/// serde-hex supports (`u64`, `[u8; N]`, ...), so one helper covers every
/// field; existing `u64` callers are unaffected.
fn serialize_hex<T, S>(v: &T, serializer: S) -> Result<S::Ok, S::Error>
where
    T: serde::Serialize + serde_hex::SerHex<serde_hex::StrictPfx>,
    S: serde::Serializer,
{
    // `is_human_readable` is how serde formats advertise text vs binary:
    // serde_json reports true, bincode reports false.
    if serializer.is_human_readable() {
        serde_hex::SerHex::<serde_hex::StrictPfx>::serialize(v, serializer)
    } else {
        v.serialize(serializer)
    }
}
// use like
// #[serde(serialize_with = "serialize_hex")]
// bin: u64
The snippet could be improved by turning the u64 into a generic.
I am attempting to call Btrieve (a very old database engine) from Rust.
This is a bit long, but this is my first attempt at FFI from Rust and
I wanted to describe everything I have done.
The Btrieve engine is implemented in a DLL, w3btrv7.dll, which is a
32-bit DLL. I have made an import library for it using 32-bit MSVC tools
(it doesn't come with an official one):
lib /Def:w3btrv7.def /Out:w3btrv7.lib /Machine:x86
I then installed the 32-bit Rust toolchain stable-i686-pc-windows-msvc
and set it as my default. Bindgen barfs on the official Btrieve headers
so I had to make my own. Luckily we only need to wrap a single function,
BTRCALL.
I have this in my wrapper.h:
/* wrapper.h: hand-written prototype for the single Btrieve entry point,
 * fed to bindgen (the official headers do not parse cleanly).
 * NOTE(review): w3btrv7.dll exports BTRCALL with the stdcall calling
 * convention; this bare prototype reads as cdecl, so bindgen emits
 * `extern "C"`. Annotating it (e.g. `__stdcall` under 32-bit MSVC) or
 * patching the generated binding to a stdcall ABI is required for a
 * correct call. */
short int BTRCALL(
unsigned short operation,    /* Btrieve operation code */
void* posBlock,              /* 128-byte position block (per caller usage) */
void* dataBuffer,            /* record data buffer */
unsigned short* dataLength,  /* presumably in: buffer size, out: bytes used — confirm */
void* keyBuffer,             /* key value buffer */
unsigned char keyLength,     /* length of keyBuffer contents */
char ckeynum);               /* key (index) number */
I am linking as:
println!("cargo:rustc-link-lib=./src/pervasive/w3btrv7");
Which seems to work: the program runs, is a 32-bit exe, and I can
see in Process Explorer that it has loaded w3btrv7.dll.
When I send the header through bindgen I get:
// Binding for the Btrieve entry point in w3btrv7.dll.
//
// Defect fixed: bindgen emitted `extern "C"` (cdecl) from the plain C
// prototype, but the DLL exports BTRCALL as stdcall. Calling a stdcall
// function through a cdecl signature leaves the stack unbalanced — the
// caller's locals get clobbered and the process dies with
// STATUS_ACCESS_VIOLATION on return, exactly the symptoms observed.
// `extern "system"` selects stdcall on 32-bit Windows targets (the
// i686-pc-windows-msvc toolchain in use) and falls back to "C" elsewhere,
// making it the idiomatic ABI for Win32-style APIs.
extern "system" {
    /// Single dispatch function for all Btrieve operations; returns a
    /// Btrieve status code.
    pub fn BTRCALL(
        operation: ::std::os::raw::c_ushort,
        posBlock: *mut ::std::os::raw::c_void,
        dataBuffer: *mut ::std::os::raw::c_void,
        dataLength: *mut ::std::os::raw::c_ushort,
        keyBuffer: *mut ::std::os::raw::c_void,
        keyLength: ::std::os::raw::c_uchar,
        ckeynum: ::std::os::raw::c_char,
    ) -> ::std::os::raw::c_short;
}
The types and sizes all seem to tally up correctly, and they match
a DllImport I have from a C# application which works perfectly:
// Known-good C# binding used as the reference for the Rust FFI signature.
// NOTE(review): DllImport defaults to CallingConvention.Winapi, i.e. StdCall
// on 32-bit Windows — matching the DLL's actual export convention. That is
// why this binding works while a Rust `extern "C"` (cdecl) declaration
// corrupts the caller's stack.
[DllImport("w3btrv7.dll", CharSet = CharSet.Ansi)]
private static extern short BTRCALL(
ushort operation, // In C#, ushort = UInt16.
[MarshalAs(UnmanagedType.LPArray, SizeConst = 128)] byte[] posBlock,
[MarshalAs(UnmanagedType.LPArray)] byte[] dataBuffer,
ref ushort dataLength,
[MarshalAs(UnmanagedType.LPArray)] byte[] keyBuffer,
byte keyLength, // unsigned byte
char keyNumber); // 2 byte char
The keyNumber is slightly different, but I have tried both bytes and shorts in both signed and unsigned variations, and it still doesn't work.
Unfortunately when I run my program it blows up after the first call
to BTRCALL. (Well, actually it's when the function that this call is in
returns). I've extracted all the params into local variables and checked
their types and all looks correct:
// Copy the call arguments out of `self` into locals so their exact types
// and contents can be inspected immediately before and after the FFI call.
let op: u16 = 0; // Btrieve operation code for this call.
let mut pos_block: [u8; 128] = self.pos_block.clone(); // position block — Btrieve writes state here (per the working C# run).
let pos_block_ptr: *mut std::ffi::c_void = pos_block.as_mut_ptr() as *mut _;
let mut data_buffer: [u8; 32768] = self.data_buffer.clone();
let data_buffer_ptr: *mut std::ffi::c_void = data_buffer.as_mut_ptr() as *mut _;
// Initialized to the full buffer capacity; presumably updated by the callee — confirm.
let mut data_length: u16 = data_buffer.len() as u16;
let mut key_buffer: [u8; 256] = self.key_buffer.clone();
let key_buffer_ptr: *mut std::ffi::c_void = key_buffer.as_mut_ptr() as *mut _;
let key_length: u8 = 255; //self.key_length;
let key_number: i8 = self.key_number.try_into().unwrap();
// NOTE(review): calling an extern fn requires an `unsafe` block (elided in
// this excerpt). The pointer locals being zeroed/overwritten after this
// returns is the classic symptom of a calling-convention mismatch in the
// `BTRCALL` declaration (cdecl vs stdcall), not of bad pointer values.
let status: i16 = BTRCALL(
op,
pos_block_ptr,
data_buffer_ptr,
&mut data_length,
key_buffer_ptr,
key_length,
key_number
);
It crashes the program with
error: process didn't exit successfully: `target\debug\blah.exe` (exit code: 0xc0000005, STATUS_ACCESS_VIOLATION)
From what I have read, this is probably due to an improper address access.
Indeed, when I put some tracing in to check the variables there is some very interesting behaviour, in that my
local variables which are passed by value seem to be getting overwritten. The log here just dumps the first
30 bytes of the buffers because the rest is just zeros:
pos_block = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pos_block_ptr = 0xad6524
data_buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
data_buffer_ptr = 0xad65a8
data_length = 32768
key_buffer = [34, 67, 58, 92, 116, 101, 109, 112, 92, 99, 115, 115, 92, 120, 100, 98, 92, 67, 65, 83, 69, 46, 68, 66, 34, 0, 0, 0, 0, 0]
key_buffer_ptr = 0xade5b0
key_length = 255
key_number = 0
>>>>>>>>>>>>>>> AFTER THE CALL TO BTRCALL:
pos_block = [0, 0, 0, 0, 0, 0, 0, 0, 0, 76, 203, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0]
pos_block_ptr = 0x0
data_buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
data_buffer_ptr = 0x42442e45
data_length = 0
key_buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
key_buffer_ptr = 0x0
key_length = 173
key_number = 0
BTRCALL() returned B_NO_ERROR
Notice pos_block_ptr has been set to 0, among other things. In contrast, a successful execution
of the exact same call from the C# code simply writes some data into the first 18 bytes of pos_block
and doesn't change any of the other variables.
It's as if it went a bit berserk and just started overwriting memory...
At this point I don't know what to try next.
Changing the declaration from extern "C" to extern "stdcall" works:
// Corrected Btrieve binding: the DLL exports BTRCALL as stdcall, so the
// cdecl `extern "C"` block had to go. Improvement over a hard-coded
// `extern "stdcall"`: `extern "system"` resolves to stdcall on 32-bit
// Windows (this crate's i686-pc-windows-msvc target) — behaviorally
// identical here — and remains valid on targets where a bare "stdcall"
// ABI is unsupported.
extern "system" {
    /// Single dispatch function for all Btrieve operations; returns a
    /// Btrieve status code.
    pub fn BTRCALL(
        operation: ::std::os::raw::c_ushort,
        posBlock: *mut ::std::os::raw::c_void,
        dataBuffer: *mut ::std::os::raw::c_void,
        dataLength: *mut ::std::os::raw::c_ushort,
        keyBuffer: *mut ::std::os::raw::c_void,
        keyLength: ::std::os::raw::c_uchar,
        ckeynum: ::std::os::raw::c_char,
    ) -> ::std::os::raw::c_short;
}