1use opentelemetry::{
2 Array, Context, KeyValue, Value,
3 trace::{SpanBuilder, SpanKind, TraceContextExt, Tracer},
4};
5use std::{
6 borrow::Cow,
7 fmt::{self, Debug, Formatter},
8 net::SocketAddr,
9 sync::Arc,
10 time::{Instant, SystemTime},
11};
12use trillium::{Conn, Handler, HeaderName, KnownHeaderName, Status, Transport};
13
/// Trait-object signature shared by the route and error-type extractors: given a
/// [`Conn`], optionally produce a `'static` string. Stored behind an `Arc` in [`Trace`].
type StringExtractionFn = dyn Fn(&Conn) -> Option<Cow<'static, str>> + Send + Sync + 'static;
15
/// Trillium handler that records an OpenTelemetry server span for each request.
///
/// Construct it with [`trace`] or [`Trace::new`] and configure it through the `with_*`
/// builder methods.
#[derive(Clone)]
pub struct Trace<T> {
    /// Optional extractor for the `http.route` attribute (also used to name the span).
    pub(crate) route: Option<Arc<StringExtractionFn>>,
    /// Optional extractor for the `error.type` attribute.
    pub(crate) error_type: Option<Arc<StringExtractionFn>>,
    /// Request headers recorded as `http.request.header.*` attributes.
    pub(crate) headers: Vec<HeaderName<'static>>,
    /// Whether handler init should capture the server's TCP socket address.
    pub(crate) enable_local_address_and_port: bool,
    // Tracer used to build the per-request server span.
    tracer: T,
    // Local socket address captured during init; stays `None` unless enabled.
    socket_addr: Option<SocketAddr>,
}
28
29impl<Span> Debug for Trace<Span> {
30 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
31 f.debug_struct("Trace")
32 .field(
33 "route",
34 &match self.route {
35 Some(_) => "Some(..)",
36 _ => "None",
37 },
38 )
39 .field(
40 "error_type",
41 &match self.error_type {
42 Some(_) => "Some(..)",
43 _ => "None",
44 },
45 )
46 .field("tracer", &"..")
47 .finish()
48 }
49}
50
51pub fn trace<T: Tracer>(tracer: T) -> Trace<T> {
53 Trace::new(tracer)
54}
55
56impl<T: Tracer> Trace<T> {
57 pub fn new(tracer: T) -> Self {
59 Trace {
60 route: None,
61 error_type: None,
62 enable_local_address_and_port: false,
63 tracer,
64 headers: vec![],
65 socket_addr: None,
66 }
67 }
68
69 pub fn with_route<F>(mut self, route: F) -> Self
81 where
82 F: Fn(&Conn) -> Option<Cow<'static, str>> + Send + Sync + 'static,
83 {
84 self.route = Some(Arc::new(route));
85 self
86 }
87
88 pub fn with_error_type<F>(mut self, error_type: F) -> Self
93 where
94 F: Fn(&Conn) -> Option<Cow<'static, str>> + Send + Sync + 'static,
95 {
96 self.error_type = Some(Arc::new(error_type));
97 self
98 }
99
100 pub fn with_headers(
102 mut self,
103 headers: impl IntoIterator<Item = impl Into<HeaderName<'static>>>,
104 ) -> Self {
105 self.headers = headers.into_iter().map(Into::into).collect();
106 self
107 }
108
109 pub fn with_local_address_and_port(mut self) -> Self {
113 self.enable_local_address_and_port = true;
114 self
115 }
116}
117
/// Per-conn state holding the OpenTelemetry [`Context`] whose active span is the server
/// span started by the `Trace` handler's `run`; `before_send` reads it back to finish
/// the span.
#[derive(Clone, Debug)]
pub(crate) struct TraceContext {
    /// Context carrying the request's server span.
    pub(crate) context: Context,
}
122
/// Marker state inserted by `run` when the route extractor already produced a route, so
/// that `before_send` does not evaluate the extractor a second time.
struct RouteWasAvailable;
124
125impl<T> Handler for Trace<T>
126where
127 T: Tracer + Send + Sync + 'static,
128 T::Span: Send + Sync + 'static,
129{
130 async fn init(&mut self, info: &mut trillium::Info) {
131 if self.enable_local_address_and_port {
132 self.socket_addr = info.tcp_socket_addr().cloned();
133 }
134 }
135 async fn run(&self, mut conn: Conn) -> Conn {
136 let start_time = Some(SystemTime::now() - conn.start_time().duration_since(Instant::now()));
137
138 let scheme = if conn.is_secure() { "https" } else { "http" };
139 let method = conn.method().as_str();
140
141 let version = conn.http_version().as_str().strip_prefix("HTTP/").unwrap();
142
143 let path_and_query = conn.path_and_query();
144 let (path, query) = match path_and_query.find('?') {
145 Some(x) => (&path_and_query[0..x], &path_and_query[x + 1..]),
146 None => (path_and_query, ""),
147 };
148
149 let mut attributes = vec![
150 KeyValue::new("http.request.method", method),
151 KeyValue::new("url.path", path.to_string()),
152 KeyValue::new("url.scheme", scheme),
153 KeyValue::new("url.query", query.to_string()),
154 KeyValue::new("network.protocol.name", "http"),
155 KeyValue::new("network.protocol.version", version),
156 ];
157
158 if let Some(socket_addr) = &self.socket_addr {
159 attributes.push(KeyValue::new(
160 "network.local.address",
161 socket_addr.ip().to_string(),
162 ));
163
164 attributes.push(KeyValue::new(
165 "network.local.port",
166 i64::from(socket_addr.port()),
167 ));
168 }
169
170 if let Some(peer_ip) = conn.peer_ip() {
171 attributes.push(KeyValue::new("client.address", peer_ip.to_string()));
172 }
173
174 for (header_name, header_values) in self.headers.iter().filter_map(|hn| {
175 conn.request_headers()
176 .get_values(hn.clone())
177 .map(|v| (hn, v))
178 }) {
179 attributes.push(KeyValue::new(
180 format!(
181 "http.request.header.{}",
182 header_name.as_ref().to_lowercase()
183 ),
184 Value::Array(Array::String(
185 header_values.iter().map(|x| x.to_string().into()).collect(),
186 )),
187 ));
188 }
189
190 let address_and_port = conn.host().map(|host| {
191 host.split_once(':')
192 .and_then(|(host, port)| Some((String::from(host), port.parse().ok()?)))
193 .unwrap_or_else(|| (String::from(host), if conn.is_secure() { 443 } else { 80 }))
194 });
195
196 if let Some((address, port)) = address_and_port {
197 attributes.push(KeyValue::new("server.address", address));
198 attributes.push(KeyValue::new("server.port", port));
199 }
200
201 if let Some(user_agent) = conn.request_headers().get_str(KnownHeaderName::UserAgent) {
202 attributes.push(KeyValue::new("user_agent.original", user_agent.to_string()));
203 }
204
205 let name = if let Some(route) = self.route.as_ref().and_then(|route| route(&conn)) {
206 conn.insert_state(RouteWasAvailable);
207 attributes.push(KeyValue::new("http.route", route.clone()));
208 format!("{} {route}", conn.method().as_str()).into()
209 } else {
210 conn.method().as_str().into()
211 };
212
213 let span = self.tracer.build(SpanBuilder {
214 name,
215 start_time,
216 span_kind: Some(SpanKind::Server),
217 attributes: Some(attributes),
218 ..SpanBuilder::default()
219 });
220 let context = Context::current_with_span(span);
221
222 conn.with_state(TraceContext { context })
223 }
224
225 async fn before_send(&self, mut conn: Conn) -> Conn {
226 let Some(TraceContext { context }) = conn.state().cloned() else {
227 return conn;
228 };
229
230 let span = context.span();
231
232 let error_type = self
233 .error_type
234 .as_ref()
235 .and_then(|et| et(&conn))
236 .or_else(|| {
237 let status = conn.status().unwrap_or(Status::NotFound);
238 if status.is_server_error() {
239 Some((status as u16).to_string().into())
240 } else {
241 None
242 }
243 });
244
245 if conn.status().is_some_and(|s| s.is_server_error()) {
246 span.set_status(opentelemetry::trace::Status::Error {
247 description: "".into(), });
249 }
250
251 let status: i64 = (conn.status().unwrap_or(Status::NotFound) as u16).into();
252
253 let mut attributes = vec![KeyValue::new("http.response.status_code", status)];
254
255 if conn.take_state::<RouteWasAvailable>().is_none() {
256 let route = self.route.as_ref().and_then(|route| route(&conn));
257 if let Some(route) = &route {
258 attributes.push(KeyValue::new("http.route", route.clone()));
259 span.update_name(format!("{} {route}", conn.method().as_str()));
260 }
261 }
262
263 if let Some(error_type) = error_type {
264 attributes.push(KeyValue::new("error.type", error_type));
265 }
266
267 span.set_attributes(attributes);
268
269 {
270 let context = context.clone();
271 let inner: &mut trillium_http::Conn<Box<dyn Transport>> = conn.as_mut();
272 inner.after_send(move |send_status| {
273 let span = context.span();
274 if !send_status.is_success() {
275 span.set_status(opentelemetry::trace::Status::Error {
276 description: "http send error".into(),
277 });
278 span.set_attribute(KeyValue::new("error.type", "http send error"));
279 }
280 span.end();
281 });
282 }
283
284 conn
285 }
286}