Include login in connection-related tracing spans/events

Nathan Sobo created

Also, include metadata on more events and add an event called "signing out" with all this metadata to make it easier to search for.

Change summary

crates/collab/src/auth.rs |  9 +++++++--
crates/collab/src/rpc.rs  | 33 ++++++++++++++++++++-------------
2 files changed, 27 insertions(+), 15 deletions(-)

Detailed changes

crates/collab/src/auth.rs 🔗

@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use super::db::{self, UserId};
 use crate::{AppState, Error};
-use anyhow::{Context, Result};
+use anyhow::{anyhow, Context, Result};
 use axum::{
     http::{self, Request, StatusCode},
     middleware::Next,
@@ -51,7 +51,12 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
     }
 
     if credentials_valid {
-        req.extensions_mut().insert(user_id);
+        let user = state
+            .db
+            .get_user_by_id(user_id)
+            .await?
+            .ok_or_else(|| anyhow!("user {} not found", user_id))?;
+        req.extensions_mut().insert(user);
         Ok::<_, Error>(next.run(req).await)
     } else {
         Err(Error::Http(

crates/collab/src/rpc.rs 🔗

@@ -2,7 +2,7 @@ mod store;
 
 use crate::{
     auth,
-    db::{self, ChannelId, MessageId, UserId},
+    db::{self, ChannelId, MessageId, User, UserId},
     AppState, Result,
 };
 use anyhow::anyhow;
@@ -49,7 +49,7 @@ use tokio::{
     time::Sleep,
 };
 use tower::ServiceBuilder;
-use tracing::{info_span, Instrument};
+use tracing::{info_span, instrument, Instrument};
 
 type MessageHandler =
     Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
@@ -244,12 +244,14 @@ impl Server {
         self: &Arc<Self>,
         connection: Connection,
         address: String,
-        user_id: UserId,
+        user: User,
         mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
         executor: E,
     ) -> impl Future<Output = Result<()>> {
         let mut this = self.clone();
-        let span = info_span!("handle connection", %user_id, %address);
+        let user_id = user.id;
+        let login = user.github_login;
+        let span = info_span!("handle connection", %user_id, %login, %address);
         async move {
             let (connection_id, handle_io, mut incoming_rx) = this
                 .peer
@@ -264,7 +266,7 @@ impl Server {
                 })
                 .await;
 
-            tracing::info!(%user_id, %connection_id, %address, "connection opened");
+            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 
             if let Some(send_connection_id) = send_connection_id.as_mut() {
                 let _ = send_connection_id.send(connection_id).await;
@@ -287,14 +289,14 @@ impl Server {
                 futures::select_biased! {
                     result = handle_io => {
                         if let Err(error) = result {
-                            tracing::error!(%error, "error handling I/O");
+                            tracing::error!(%error, %user_id, %login, %connection_id, %address, "error handling I/O");
                         }
                         break;
                     }
                     message = next_message => {
                         if let Some(message) = message {
                             let type_name = message.payload_type_name();
-                            let span = tracing::info_span!("receive message", %user_id, %connection_id, %address, type_name);
+                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                             async {
                                 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                     let notifications = this.notifications.clone();
@@ -312,25 +314,27 @@ impl Server {
                                         handle_message.await;
                                     }
                                 } else {
-                                    tracing::error!("no message handler");
+                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                                 }
                             }.instrument(span).await;
                         } else {
-                            tracing::info!(%user_id, %connection_id, %address, "connection closed");
+                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                             break;
                         }
                     }
                 }
             }
 
+            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
             if let Err(error) = this.sign_out(connection_id).await {
-                tracing::error!(%error, "error signing out");
+                tracing::error!(%user_id, %login, %connection_id, %address, %error, "error signing out");
             }
 
             Ok(())
         }.instrument(span)
     }
 
+    #[instrument(skip(self), err)]
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
         self.peer.disconnect(connection_id);
         let removed_connection = self.store_mut().await.remove_connection(connection_id)?;
@@ -1420,7 +1424,7 @@ pub async fn handle_websocket_request(
     TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
     ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
     Extension(server): Extension<Arc<Server>>,
-    Extension(user_id): Extension<UserId>,
+    Extension(user): Extension<User>,
     ws: WebSocketUpgrade,
 ) -> axum::response::Response {
     if protocol_version != rpc::PROTOCOL_VERSION {
@@ -1440,7 +1444,7 @@ pub async fn handle_websocket_request(
         let connection = Connection::new(Box::pin(socket));
         async move {
             server
-                .handle_connection(connection, socket_address, user_id, None, RealExecutor)
+                .handle_connection(connection, socket_address, user, None, RealExecutor)
                 .await
                 .log_err();
         }
@@ -6451,6 +6455,7 @@ mod tests {
             let client_name = name.to_string();
             let mut client = Client::new(http.clone());
             let server = self.server.clone();
+            let db = self.app_state.db.clone();
             let connection_killers = self.connection_killers.clone();
             let forbid_connections = self.forbid_connections.clone();
             let (connection_id_tx, mut connection_id_rx) = mpsc::channel(16);
@@ -6471,6 +6476,7 @@ mod tests {
                     assert_eq!(credentials.access_token, "the-token");
 
                     let server = server.clone();
+                    let db = db.clone();
                     let connection_killers = connection_killers.clone();
                     let forbid_connections = forbid_connections.clone();
                     let client_name = client_name.clone();
@@ -6484,11 +6490,12 @@ mod tests {
                             let (client_conn, server_conn, killed) =
                                 Connection::in_memory(cx.background());
                             connection_killers.lock().insert(user_id, killed);
+                            let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
                             cx.background()
                                 .spawn(server.handle_connection(
                                     server_conn,
                                     client_name,
-                                    user_id,
+                                    user,
                                     Some(connection_id_tx),
                                     cx.background(),
                                 ))