use proc_macro::TokenStream;
use quote::quote;
use std::collections::HashSet;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token::{Comma, Where},
visit::{visit_type_path, Visit},
Data, DeriveInput, Error, Field, Ident, Index, Member, Type, TypePath, WhereClause,
};
fn is_required_generic_for_type(ty: &Type, generic: &Ident) -> bool {
struct PathVisitor<'g> {
generic: &'g Ident,
generic_is_required: bool,
}
impl<'g, 'ast> Visit<'ast> for PathVisitor<'g> {
fn visit_type_path(&mut self, node: &'ast TypePath) {
if node.qself.is_none() {
if let Some(first_segment) = node.path.segments.first() {
if first_segment.ident == *self.generic {
self.generic_is_required = true;
}
}
}
visit_type_path(self, node);
}
}
let mut path_visitor = PathVisitor {
generic,
generic_is_required: false,
};
path_visitor.visit_type(ty);
path_visitor.generic_is_required
}
struct DeriveOptions {
input: DeriveInput,
field: Field,
field_index: usize,
}
fn generics(field: &Field, input: &DeriveInput) -> Vec<Ident> {
input
.generics
.type_params()
.filter_map(|g| {
if is_required_generic_for_type(&field.ty, &g.ident) {
Some(g.ident.clone())
} else {
None
}
})
.collect::<HashSet<_>>()
.into_iter()
.collect()
}
impl Parse for DeriveOptions {
fn parse(input: ParseStream) -> syn::Result<Self> {
let input = DeriveInput::parse(input)?;
let Data::Struct(ds) = &input.data else {
return Err(Error::new(input.span(), "second error"));
};
for (field_index, field) in ds.fields.iter().enumerate() {
for attr in &field.attrs {
if attr.path().is_ident("async_write") || attr.path().is_ident("async_io") {
let field = field.clone();
return Ok(Self {
input,
field,
field_index,
});
}
}
}
if ds.fields.len() == 1 {
let field = ds
.fields
.iter()
.next()
.expect("len == 1 should have one element")
.clone();
Ok(Self {
input,
field,
field_index: 0,
})
} else {
Err(Error::new(
input.span(),
"Structs with more than one field need an #[async_io] or #[async_write] annotation",
))
}
}
}
pub fn derive_async_write(input: TokenStream) -> TokenStream {
let DeriveOptions {
field,
input,
field_index,
} = parse_macro_input!(input as DeriveOptions);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let generics = generics(&field, &input);
let struct_name = input.ident;
let mut where_clause = where_clause.map_or_else(
|| WhereClause {
where_token: Where::default(),
predicates: Punctuated::new(),
},
|where_clause| where_clause.to_owned(),
);
for generic in generics {
where_clause
.predicates
.push_value(parse_quote! { #generic: AsyncWrite + Unpin });
where_clause.predicates.push_punct(Comma::default());
}
where_clause
.predicates
.push_value(parse_quote! { Self: Unpin });
let handler = field
.ident
.map_or_else(|| Member::Unnamed(Index::from(field_index)), Member::Named);
let handler = quote!(self.#handler);
quote! {
impl #impl_generics AsyncWrite for #struct_name #ty_generics #where_clause {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut #handler).poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut #handler).poll_flush(cx)
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut #handler).poll_close(cx)
}
fn poll_write_vectored(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>]
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut #handler).poll_write_vectored(cx, bufs)
}
}
}
.into()
}