Rework context passing.
wojciech-graj committed May 18, 2024
1 parent a97196c commit b50116b
Showing 32 changed files with 462 additions and 313 deletions.
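The headline change (see the CHANGELOG entry below) is that the read/write context is now threaded through a generic type parameter instead of a type-erased `Any` value. The generated code in this diff calls `bin_proto::Protocol::<Ctx>::read(__io_reader, __byte_order, __ctx)`, which implies a trait shaped roughly like the sketch below. The trait definition itself is not part of this diff, so the byte-order type, result type, and write-side signature here are assumptions rather than bin-proto's actual API.

use std::io;

// Hypothetical stand-ins for bin-proto's byte-order and result types (assumed,
// not taken from the crate).
pub enum ByteOrder {
    BigEndian,
    LittleEndian,
}
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

// Sketch only: before this commit the context argument was type-erased; after
// it, the trait is generic over a caller-chosen context type `Ctx`, so a field
// type can demand exactly the context it needs at compile time.
pub trait Protocol<Ctx>: Sized {
    fn read(read: &mut dyn io::Read, byte_order: ByteOrder, ctx: &mut Ctx) -> Result<Self>;
    fn write(&self, write: &mut dyn io::Write, byte_order: ByteOrder, ctx: &mut Ctx) -> Result<()>;
}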
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@
- Delete `EnumExt`
- Bump rust version to 2021
- Make lifetime generics work
- Handle context using generics instead of `Any`
# v0.3.4
- Do not trigger https://github.com/rust-lang/rust/issues/120363 with generated code
# v0.3.3
1 change: 1 addition & 0 deletions Cargo.toml
@@ -6,3 +6,4 @@ members = [
exclude = [
"bench",
]
resolver = "2"
2 changes: 1 addition & 1 deletion bin-proto-derive/Cargo.toml
@@ -17,6 +17,6 @@ keywords = ["protocol", "binary", "bit", "codec", "serde"]
proc-macro = true

[dependencies]
syn = { version = "1.0.109", features = ["default", "extra-traits", "fold"] }
syn = { version = "1.0.109", features = ["default", "extra-traits", "parsing"] }
quote = "1.0.35"
proc-macro2 = "1.0.79"
48 changes: 47 additions & 1 deletion bin-proto-derive/src/attr.rs
@@ -1,7 +1,12 @@
use proc_macro2::TokenStream;
use syn::{parse::Parser, punctuated::Punctuated, token::Add, TypeParamBound};

#[derive(Debug, Default)]
pub struct Attrs {
pub discriminant_type: Option<syn::Ident>,
pub discriminant_type: Option<syn::Type>,
pub discriminant: Option<syn::Expr>,
pub ctx: Option<syn::Type>,
pub ctx_bounds: Option<Punctuated<TypeParamBound, Add>>,
pub write_value: Option<syn::Expr>,
pub bits: Option<u32>,
pub flexible_array_member: bool,
@@ -16,6 +21,9 @@ impl Attrs {
if self.discriminant.is_some() {
panic!("unexpected discriminant attribute for enum")
}
if self.ctx.is_some() && self.ctx_bounds.is_some() {
panic!("cannot specify ctx and ctx_bounds simultaneously")
}
if self.write_value.is_some() {
panic!("unexpected write_value attribute for enum")
}
@@ -31,6 +39,12 @@ impl Attrs {
if self.discriminant_type.is_some() {
panic!("unexpected discriminant_type attribute for variant")
}
if self.ctx.is_some() {
panic!("unexpected ctx attribute for variant")
}
if self.ctx_bounds.is_some() {
panic!("unexpected ctx_bounds attribute for variant")
}
if self.write_value.is_some() {
panic!("unexpected write_value attribute for variant")
}
@@ -52,6 +66,12 @@ impl Attrs {
if self.discriminant.is_some() {
panic!("unexpected discriminant attribute for field")
}
if self.ctx.is_some() {
panic!("unexpected ctx attribute for field")
}
if self.ctx_bounds.is_some() {
panic!("unexpected ctx_bounds attribute for field")
}
if [
self.bits.is_some(),
self.flexible_array_member,
@@ -65,6 +85,13 @@
panic!("bits, flexible_array_member, and length are mutually-exclusive attributes")
}
}

pub fn ctx_tok(&self) -> TokenStream {
self.ctx
.clone()
.map(|ctx| quote!(#ctx))
.unwrap_or(quote!(__Ctx))
}
}

impl From<&[syn::Attribute]> for Attrs {
@@ -97,6 +124,11 @@ impl From<&[syn::Attribute]> for Attrs {
attribs.discriminant =
Some(meta_name_value_to_parse(name_value))
}
"ctx" => attribs.ctx = Some(meta_name_value_to_parse(name_value)),
"ctx_bounds" => {
attribs.ctx_bounds =
Some(meta_name_value_to_punctuated(name_value))
}
"bits" => attribs.bits = Some(meta_name_value_to_u32(name_value)),
"write_value" => {
attribs.write_value = Some(meta_name_value_to_parse(name_value))
@@ -149,3 +181,17 @@ fn meta_name_value_to_u32(name_value: syn::MetaNameValue) -> u32 {
_ => panic!("bitfield size must be an integer"),
}
}

fn meta_name_value_to_punctuated<T: syn::parse::Parse, P: syn::parse::Parse>(
name_value: syn::MetaNameValue,
) -> Punctuated<T, P> {
match name_value.lit {
syn::Lit::Str(s) => match Punctuated::parse_terminated.parse_str(s.value().as_str()) {
Ok(f) => f,
Err(_) => {
panic!("Failed to parse '{}'", s.value())
}
},
_ => panic!("#[protocol(... = \"...\")] must be string"),
}
}
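The attribute parser above now understands two new container-level attributes: `ctx` (a concrete context type, parsed from a string literal into a `syn::Type`) and `ctx_bounds` (a `+`-separated list of trait bounds parsed by `meta_name_value_to_punctuated`), and the validation above rejects specifying both at once. A hedged usage sketch follows; the derive name `Protocol` and the exact attribute spelling are inferred from this parser rather than taken from documentation.

use bin_proto::Protocol;

// Application-defined context threaded through every read/write call.
#[derive(Debug)]
pub struct Ctx {
    pub version: u32,
}

// Pin the generated impl to one concrete context type...
#[derive(Protocol)]
#[protocol(ctx = "Ctx")]
pub struct Packet {
    pub id: u8,
    #[protocol(bits = 4)]
    pub flags: u8,
}

// ...or stay generic over any context satisfying the listed bounds.
#[derive(Protocol)]
#[protocol(ctx_bounds = "std::fmt::Debug")]
pub struct Header {
    pub len: u16,
}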
40 changes: 14 additions & 26 deletions bin-proto-derive/src/codegen/enums.rs
@@ -1,4 +1,5 @@
use crate::{
attr::Attrs,
codegen,
plan::{self, EnumVariant},
};
@@ -19,8 +20,7 @@ pub fn write_variant(

let write_discriminant = write_discriminant(variant);

let (binding_names, fields_pattern) =
bind_fields_pattern(variant_name, &variant.fields);
let fields_pattern = bind_fields_pattern(variant_name, &variant.fields);

let writes = codegen::writes(&variant.fields, false);

@@ -39,13 +39,17 @@
)
}

pub fn read_variant(plan: &plan::Enum, read_discriminant: TokenStream) -> TokenStream {
pub fn read_variant(
plan: &plan::Enum,
read_discriminant: TokenStream,
attribs: &Attrs,
) -> TokenStream {
let discriminant_ty = plan.discriminant_ty.clone();

let discriminant_match_branches = plan.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let discriminant_literal = variant.discriminant_value.clone();
let (reader, initializer) = codegen::reads(&variant.fields);
let (reader, initializer) = codegen::reads(&variant.fields, attribs);

quote!(
#discriminant_literal => {
@@ -74,28 +78,16 @@ pub fn read_variant(plan: &plan::Enum, read_discriminant: TokenStream) -> TokenStream {
/// Generates code for a pattern that binds a set of fields by reference.
///
/// Returns the pattern tokens.
pub fn bind_fields_pattern(
parent_name: &syn::Ident,
fields: &syn::Fields,
) -> (Vec<syn::Ident>, TokenStream) {
pub fn bind_fields_pattern(parent_name: &syn::Ident, fields: &syn::Fields) -> TokenStream {
match *fields {
syn::Fields::Named(ref fields_named) => {
let field_names: Vec<_> = fields_named
.named
.iter()
.map(|f| f.ident.clone().unwrap())
.collect();
let field_name_refs = fields_named
.named
.iter()
.map(|f| &f.ident)
.map(|n| quote!( ref #n ));

(
field_names,
quote!(
#parent_name { #( #field_name_refs ),* }
),
quote!(
#parent_name { #( #field_name_refs ),* }
)
}
syn::Fields::Unnamed(ref fields_unnamed) => {
@@ -104,14 +96,10 @@
.collect();

let field_refs: Vec<_> = binding_names.iter().map(|i| quote!( ref #i )).collect();

(
binding_names,
quote!(
#parent_name ( #( #field_refs ),* )
),
quote!(
#parent_name ( #( #field_refs ),* )
)
}
syn::Fields::Unit => (Vec::new(), quote!(#parent_name)),
syn::Fields::Unit => quote!(#parent_name),
}
}
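`bind_fields_pattern` now returns just the binding pattern, since `write_variant` above no longer needs the list of binding names. A minimal standalone sketch (plain `quote`, with made-up variant and field names) of the token stream the named-fields arm produces:

use quote::{format_ident, quote};

fn main() {
    // Stand-ins for the derive's inputs: a variant `MyVariant` with fields `a` and `b`.
    let parent_name = format_ident!("MyVariant");
    let field_name_refs = ["a", "b"].iter().map(|name| {
        let ident = format_ident!("{}", name);
        quote!( ref #ident )
    });

    // Mirrors the named-fields arm above: a by-reference binding pattern.
    let pattern = quote!( #parent_name { #( #field_name_refs ),* } );

    // Prints roughly: MyVariant { ref a , ref b }
    println!("{}", pattern);
}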
32 changes: 20 additions & 12 deletions bin-proto-derive/src/codegen/mod.rs
@@ -2,12 +2,13 @@ pub mod enums;

use crate::attr::Attrs;
use proc_macro2::TokenStream;
use syn::{fold::Fold, Expr, ExprField};

pub fn reads(fields: &syn::Fields) -> (TokenStream, TokenStream) {
pub fn reads(fields: &syn::Fields, attrs: &Attrs) -> (TokenStream, TokenStream) {
match *fields {
syn::Fields::Named(ref fields_named) => read_named_fields(fields_named),
syn::Fields::Unnamed(ref fields_unnamed) => (quote!(), read_unnamed_fields(fields_unnamed)),
syn::Fields::Named(ref fields_named) => read_named_fields(fields_named, attrs),
syn::Fields::Unnamed(ref fields_unnamed) => {
(quote!(), read_unnamed_fields(fields_unnamed, attrs))
}
syn::Fields::Unit => (quote!(), quote!()),
}
}
@@ -22,15 +23,15 @@ pub fn writes(fields: &syn::Fields, self_prefix: bool) -> TokenStream {
}
}

fn read_named_fields(fields_named: &syn::FieldsNamed) -> (TokenStream, TokenStream) {
fn read_named_fields(fields_named: &syn::FieldsNamed, attrs: &Attrs) -> (TokenStream, TokenStream) {
let fields: Vec<_> = fields_named
.named
.iter()
.map(|field| {
let field_name = &field.ident;
let field_ty = &field.ty;

let read = read(field);
let read = read(field, attrs);

quote!(
let #field_name : #field_ty = #read?;
@@ -56,12 +57,15 @@
)
}

fn read(field: &syn::Field) -> TokenStream {
fn read(field: &syn::Field, parent_attribs: &Attrs) -> TokenStream {
let attribs = Attrs::from(field.attrs.as_slice());
attribs.validate_field();

let ctx_ty = parent_attribs.ctx_tok();

if let Some(field_width) = attribs.bits {
quote!(
bin_proto::BitField::read(__io_reader, __byte_order, __ctx, #field_width)
bin_proto::BitField::<#ctx_ty>::read(__io_reader, __byte_order, __ctx, #field_width)
)
} else if attribs.flexible_array_member {
quote!(bin_proto::FlexibleArrayMember::read(
@@ -71,10 +75,14 @@ fn read(field: &syn::Field) -> TokenStream {
))
} else if let Some(length) = attribs.length {
quote!(
bin_proto::ExternallyLengthPrefixed::read(__io_reader, __byte_order, __ctx, #length)
bin_proto::ExternallyLengthPrefixed::<#ctx_ty>::read(__io_reader, __byte_order, __ctx, #length)
)
} else {
quote!(bin_proto::Protocol::read(__io_reader, __byte_order, __ctx))
quote!(bin_proto::Protocol::<#ctx_ty>::read(
__io_reader,
__byte_order,
__ctx
))
}
}

@@ -138,13 +146,13 @@ fn write_named_fields(fields_named: &syn::FieldsNamed, self_prefix: bool) -> TokenStream {
quote!( #( #field_writers );* )
}

fn read_unnamed_fields(fields_unnamed: &syn::FieldsUnnamed) -> TokenStream {
fn read_unnamed_fields(fields_unnamed: &syn::FieldsUnnamed, attrs: &Attrs) -> TokenStream {
let field_initializers: Vec<_> = fields_unnamed
.unnamed
.iter()
.map(|field| {
let field_ty = &field.ty;
let read = read(field);
let read = read(field, attrs);

quote!(
{
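The updated `read` helper splices a context type token (from `Attrs::ctx_tok`, which falls back to the generic `__Ctx` parameter when no `#[protocol(ctx = "...")]` is given) into each generated call so the right trait instantiation is selected. Below is a standalone sketch of that token-splicing step, reduced to the bit-field and plain-`Protocol` branches; `__io_reader`, `__byte_order`, and `__ctx` are the identifiers the derive emits, as seen above.

use proc_macro2::TokenStream;
use quote::quote;

// Simplified mirror of read(): choose a reader expression and parameterize the
// trait with the supplied context type token.
fn read_expr(ctx_ty: &TokenStream, bits: Option<u32>) -> TokenStream {
    match bits {
        Some(width) => quote!(
            bin_proto::BitField::<#ctx_ty>::read(__io_reader, __byte_order, __ctx, #width)
        ),
        None => quote!(
            bin_proto::Protocol::<#ctx_ty>::read(__io_reader, __byte_order, __ctx)
        ),
    }
}

fn main() {
    // Without a #[protocol(ctx = "...")] attribute, ctx_tok() falls back to `__Ctx`.
    let ctx_ty = quote!(__Ctx);
    println!("{}", read_expr(&ctx_ty, Some(4)));
    println!("{}", read_expr(&ctx_ty, None));
}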