trillium_static_compiled_macros/
lib.rs1use proc_macro::{TokenStream, TokenTree};
19use proc_macro2::Literal;
20use quote::quote;
21use std::{
22 error::Error,
23 fmt::{self, Display, Formatter},
24 path::{Path, PathBuf},
25 time::SystemTime,
26};
27
28#[proc_macro]
30pub fn include_dir(input: TokenStream) -> TokenStream {
31 let tokens: Vec<_> = input.into_iter().collect();
32
33 let path = match tokens.as_slice() {
34 [TokenTree::Literal(lit)] => unwrap_string_literal(lit),
35 _ => panic!("This macro only accepts a single, non-empty string argument"),
36 };
37
38 let path = resolve_path(&path, get_env)
39 .unwrap()
40 .canonicalize()
41 .unwrap();
42
43 expand_dir(&path, &path).into()
44}
45
46#[proc_macro]
48pub fn include_entry(input: TokenStream) -> TokenStream {
49 let tokens: Vec<_> = input.into_iter().collect();
50
51 let path = match tokens.as_slice() {
52 [TokenTree::Literal(lit)] => unwrap_string_literal(lit),
53 _ => panic!("This macro only accepts a single, non-empty string argument"),
54 };
55
56 let path = resolve_path(&path, get_env)
57 .unwrap()
58 .canonicalize()
59 .unwrap();
60
61 expand_entry(&path, &path).into()
62}
63
64fn unwrap_string_literal(lit: &proc_macro::Literal) -> String {
65 let mut repr = lit.to_string();
66 if !repr.starts_with('"') || !repr.ends_with('"') {
67 panic!("This macro only accepts a single, non-empty string argument")
68 }
69
70 repr.remove(0);
71 repr.pop();
72
73 repr
74}
75
76fn expand_entry(root: &Path, child: &Path) -> proc_macro2::TokenStream {
77 if child.is_dir() {
78 let tokens = expand_dir(root, child);
79 quote!(DirEntry::Dir(#tokens))
80 } else if child.is_file() {
81 let tokens = expand_file(root, child);
82 quote!(DirEntry::File(#tokens))
83 } else {
84 panic!("\"{}\" is neither a file nor a directory", child.display());
85 }
86}
87
88fn expand_dir(root: &Path, path: &Path) -> proc_macro2::TokenStream {
89 let children = read_dir(path).unwrap_or_else(|e| {
90 panic!(
91 "Unable to read the entries in \"{}\": {}",
92 path.display(),
93 e
94 )
95 });
96
97 let child_tokens = children
98 .iter()
99 .map(|child| expand_entry(root, child))
100 .collect::<Vec<_>>();
101
102 let path = normalize_path(root, path);
103
104 quote!(Dir::new(#path, &[ #(#child_tokens),* ]))
105}
106
107fn expand_file(root: &Path, path: &Path) -> proc_macro2::TokenStream {
108 let contents = read_file(path);
109 let literal = Literal::byte_string(&contents);
110
111 let normalized_path = if root == path {
115 path.file_name()
116 .map(|n| n.to_string_lossy().into_owned())
117 .unwrap_or_default()
118 } else {
119 normalize_path(root, path)
120 };
121
122 let tokens = quote!(File::new(#normalized_path, #literal));
123
124 match metadata(path) {
125 Some(metadata) => quote!(#tokens.with_metadata(#metadata)),
126 None => tokens,
127 }
128}
129
130fn metadata(path: &Path) -> Option<proc_macro2::TokenStream> {
131 fn to_unix(t: SystemTime) -> u64 {
132 t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
133 }
134
135 let meta = path.metadata().ok()?;
136 let accessed = meta.accessed().map(to_unix).ok()?;
137 let created = meta.created().map(to_unix).ok()?;
138 let modified = meta.modified().map(to_unix).ok()?;
139
140 Some(quote!(Metadata::from_secs(#accessed, #created, #modified)))
141}
142
/// Returns `path` relative to `root` as a `/`-separated string.
///
/// The path's components are joined with `/` on every host OS, so embedded paths look
/// the same everywhere. Joining components avoids a problem with the old approach of
/// replacing every `\` with `/`: a backslash that is part of a Unix file name stays
/// intact instead of being turned into a separator.
///
/// # Panics
///
/// Panics if `path` does not start with `root`.
fn normalize_path(root: &Path, path: &Path) -> String {
    let stripped = path
        .strip_prefix(root)
        .expect("Should only ever be called using paths inside the root path");

    stripped
        .components()
        .map(|component| component.as_os_str().to_string_lossy())
        .collect::<Vec<_>>()
        .join("/")
}
153
/// Lists the entries of `dir` as full paths, in sorted `PathBuf` order, so the result
/// does not depend on the order the OS yields entries.
///
/// # Errors
///
/// Returns the first I/O error hit while opening or iterating the directory.
///
/// # Panics
///
/// Panics if `dir` is not a directory.
fn read_dir(dir: &Path) -> Result<Vec<PathBuf>, Box<dyn Error>> {
    if !dir.is_dir() {
        panic!("\"{}\" is not a directory", dir.display());
    }

    let mut entries = dir
        .read_dir()?
        .map(|entry| entry.map(|e| e.path()))
        .collect::<Result<Vec<PathBuf>, _>>()?;

    entries.sort();
    Ok(entries)
}
170
/// Reads the entire contents of the file at `path`.
///
/// # Panics
///
/// Panics with a message naming `path` if the file cannot be read.
fn read_file(path: &Path) -> Vec<u8> {
    match std::fs::read(path) {
        Ok(contents) => contents,
        Err(error) => panic!("Unable to read \"{}\": {}", path.display(), error),
    }
}
174
175fn resolve_path(
176 raw: &str,
177 get_env: impl Fn(&str) -> Option<String>,
178) -> Result<PathBuf, Box<dyn Error>> {
179 let mut unprocessed = raw;
180 let mut resolved = String::new();
181
182 while let Some(dollar_sign) = unprocessed.find('$') {
183 let (head, tail) = unprocessed.split_at(dollar_sign);
184 resolved.push_str(head);
185
186 match parse_identifier(&tail[1..]) {
187 Some((variable, rest)) => {
188 let value = get_env(variable).ok_or_else(|| MissingVariable {
189 variable: variable.to_string(),
190 })?;
191 resolved.push_str(&value);
192 unprocessed = rest;
193 }
194 None => {
195 return Err(UnableToParseVariable { rest: tail.into() }.into());
196 }
197 }
198 }
199 resolved.push_str(unprocessed);
200
201 let path = PathBuf::from(resolved);
202 if path.is_relative() {
203 Ok(PathBuf::from(
204 get_env("CARGO_MANIFEST_DIR").ok_or_else(|| MissingVariable {
205 variable: "CARGO_MANIFEST_DIR".to_string(),
206 })?,
207 )
208 .join(path))
209 } else {
210 Ok(path)
211 }
212}
213
/// Error produced by `resolve_path` when `get_env` has no value for a variable.
///
/// This covers a `$NAME` reference and also `CARGO_MANIFEST_DIR` for relative paths.
#[derive(Debug, PartialEq)]
struct MissingVariable {
    /// The variable's name, without a leading `$`.
    variable: String,
}

impl Display for MissingVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Unable to resolve $")?;
        f.write_str(&self.variable)
    }
}

impl Error for MissingVariable {}
226
/// Error produced by `resolve_path` when a `$` is not followed by a valid identifier.
#[derive(Debug, PartialEq)]
struct UnableToParseVariable {
    /// The unparsed input, starting at the offending `$`.
    rest: String,
}

impl Display for UnableToParseVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Unable to parse a variable from \"")?;
        f.write_str(&self.rest)?;
        f.write_str("\"")
    }
}

impl Error for UnableToParseVariable {}
239
240fn parse_identifier(text: &str) -> Option<(&str, &str)> {
241 let mut calls = 0;
242
243 let (head, tail) = take_while(text, |c| {
244 calls += 1;
245
246 match c {
247 '_' => true,
248 letter if letter.is_ascii_alphabetic() => true,
249 digit if digit.is_ascii_digit() && calls > 1 => true,
250 _ => false,
251 }
252 });
253
254 if head.is_empty() {
255 None
256 } else {
257 Some((head, tail))
258 }
259}
260
/// Splits `s` just before the first character rejected by `predicate`.
///
/// `predicate` is called on each character in order until it first returns `false`,
/// and no further. The split always falls on a `char` boundary.
fn take_while(s: &str, mut predicate: impl FnMut(char) -> bool) -> (&str, &str) {
    let split_at = s
        .char_indices()
        .find(|&(_, c)| !predicate(c))
        .map_or(s.len(), |(index, _)| index);

    s.split_at(split_at)
}
274
/// Looks up `variable` in the current process environment.
///
/// Returns `None` when the variable is unset or its value is not valid Unicode.
fn get_env(variable: &str) -> Option<String> {
    match std::env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_path_with_no_environment_variables() {
        let path = "./file.txt";

        let resolved = resolve_path(path, |name| {
            assert_eq!(name, "CARGO_MANIFEST_DIR");
            Some("/files/cargo_manifest_dir".to_string())
        })
        .unwrap();

        assert_eq!(
            resolved.to_str().unwrap(),
            PathBuf::from("/files/cargo_manifest_dir")
                .join("./file.txt")
                .to_str()
                .unwrap()
        );
    }

    #[test]
    fn simple_environment_variable() {
        let path = "../$VAR";

        let resolved = resolve_path(path, |name| match name {
            "VAR" => Some("file.txt".to_string()),
            "CARGO_MANIFEST_DIR" => Some("/files/cargo_manifest_dir".to_string()),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(
            resolved.to_str().unwrap(),
            PathBuf::from("/files/cargo_manifest_dir")
                .join("../file.txt")
                .to_str()
                .unwrap()
        );
    }

    #[test]
    fn dont_resolve_recursively() {
        let path = "./$TOP_LEVEL.txt";

        let resolved = resolve_path(path, |name| match name {
            "TOP_LEVEL" => Some("$NESTED".to_string()),
            "CARGO_MANIFEST_DIR" => Some("/files/cargo_manifest_dir".to_string()),
            // Lookups receive the bare name, so a recursive expansion would ask for
            // "NESTED" (not "$NESTED").
            "NESTED" => unreachable!("Shouldn't resolve recursively"),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(
            resolved.to_str().unwrap(),
            PathBuf::from("/files/cargo_manifest_dir")
                .join("./$NESTED.txt")
                .to_str()
                .unwrap()
        );
    }

    // "/srv/..." is only absolute on Unix; Windows needs a drive prefix.
    #[cfg(unix)]
    #[test]
    fn absolute_paths_are_not_joined_to_the_manifest_dir() {
        let resolved = resolve_path("/srv/$DIR", |name| match name {
            "DIR" => Some("assets".to_string()),
            _ => unreachable!("CARGO_MANIFEST_DIR is only needed for relative paths"),
        })
        .unwrap();

        assert_eq!(resolved, PathBuf::from("/srv/assets"));
    }

    #[test]
    fn parse_valid_identifiers() {
        let inputs = vec![
            ("a", "a"),
            ("a_", "a_"),
            ("_asf", "_asf"),
            ("a1", "a1"),
            ("a1_#sd", "a1_"),
            ("A_B9/rest", "A_B9"),
        ];

        for (src, expected) in inputs {
            let (got, rest) = parse_identifier(src).unwrap();
            assert_eq!(got.len() + rest.len(), src.len());
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn parse_invalid_identifiers() {
        for src in &["", "1abc", "-x", "{VAR}", "é"] {
            assert_eq!(parse_identifier(src), None, "input: {:?}", src);
        }
    }

    #[test]
    fn take_while_splits_on_char_boundaries() {
        assert_eq!(take_while("héllo wörld", |c| c != ' '), ("héllo", " wörld"));
        assert_eq!(take_while("", |_| true), ("", ""));
    }

    #[test]
    fn unknown_environment_variable() {
        let path = "$UNKNOWN";

        let err = resolve_path(path, |_| None).unwrap_err();

        let missing_variable = err.downcast::<MissingVariable>().unwrap();
        assert_eq!(
            *missing_variable,
            MissingVariable {
                variable: String::from("UNKNOWN"),
            }
        );
    }

    #[test]
    fn missing_manifest_dir_for_relative_path() {
        let err = resolve_path("assets", |_| None).unwrap_err();

        let missing_variable = err.downcast::<MissingVariable>().unwrap();
        assert_eq!(
            *missing_variable,
            MissingVariable {
                variable: String::from("CARGO_MANIFEST_DIR"),
            }
        );
    }

    #[test]
    fn invalid_variables() {
        // (input, expected `rest`): `rest` starts at the offending `$`.
        let inputs = &[
            ("$1", "$1"),
            ("$", "$"),
            ("dir/$/file", "$/file"),
            ("${VAR}", "${VAR}"),
        ];

        for (input, rest) in inputs {
            let err = resolve_path(input, |_| unreachable!()).unwrap_err();

            let err = err.downcast::<UnableToParseVariable>().unwrap();
            assert_eq!(
                *err,
                UnableToParseVariable {
                    rest: rest.to_string(),
                }
            );
        }
    }

    #[test]
    fn normalize_path_is_relative_and_slash_separated() {
        let root = Path::new("/root");
        assert_eq!(normalize_path(root, Path::new("/root/a/b.txt")), "a/b.txt");
        assert_eq!(normalize_path(root, root), "");
    }
}