Implement /rpc_server_snapshot endpoint

Nathan Sobo created

This returns a JSON snapshot of the state of the server

Change summary

crates/collab/src/api.rs          | 14 +++++++++++---
crates/collab/src/main.rs         |  6 ++++++
crates/collab/src/rpc.rs          | 26 ++++++++++++++++++++++++++
crates/collab/src/rpc/store.rs    | 11 +++++++++--
crates/rpc/src/peer.rs            | 18 ++++++++++++++++--
styles/src/styleTree/workspace.ts |  1 -
6 files changed, 68 insertions(+), 8 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     auth,
     db::{User, UserId},
-    rpc::ResultExt,
+    rpc::{self, ResultExt},
     AppState, Error, Result,
 };
 use anyhow::anyhow;
@@ -15,11 +15,12 @@ use axum::{
     Extension, Json, Router,
 };
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use std::sync::Arc;
 use tower::ServiceBuilder;
 use tracing::instrument;
 
-pub fn routes(rpc_server: &Arc<crate::rpc::Server>, state: Arc<AppState>) -> Router<Body> {
+pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body> {
     Router::new()
         .route("/users", get(get_users).post(create_user))
         .route(
@@ -29,6 +30,7 @@ pub fn routes(rpc_server: &Arc<crate::rpc::Server>, state: Arc<AppState>) -> Rou
         .route("/users/:id/access_tokens", post(create_access_token))
         .route("/invite_codes/:code", get(get_user_for_invite_code))
         .route("/panic", post(trace_panic))
+        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
         .layer(
             ServiceBuilder::new()
                 .layer(Extension(state))
@@ -84,7 +86,7 @@ struct CreateUserParams {
 async fn create_user(
     Json(params): Json<CreateUserParams>,
     Extension(app): Extension<Arc<AppState>>,
-    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
+    Extension(rpc_server): Extension<Arc<rpc::Server>>,
 ) -> Result<Json<User>> {
     println!("{:?}", params);
 
@@ -177,6 +179,12 @@ async fn trace_panic(panic: Json<Panic>) -> Result<()> {
     Ok(())
 }
 
+async fn get_rpc_server_snapshot<'a>(
+    Extension(rpc_server): Extension<Arc<rpc::Server>>,
+) -> Result<Json<Value>> {
+    Ok(Json(serde_json::to_value(rpc_server.snapshot().await)?))
+}
+
 #[derive(Deserialize)]
 struct CreateAccessTokenQueryParams {
     public_key: String,

crates/collab/src/main.rs 🔗

@@ -104,6 +104,12 @@ impl From<hyper::Error> for Error {
     }
 }
 
+impl From<serde_json::Error> for Error {
+    fn from(error: serde_json::Error) -> Self {
+        Self::Internal(error.into())
+    }
+}
+
 impl IntoResponse for Error {
     fn into_response(self) -> axum::response::Response {
         match self {

crates/collab/src/rpc.rs 🔗

@@ -33,6 +33,7 @@ use rpc::{
     proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
     Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
 };
+use serde::{Serialize, Serializer};
 use std::{
     any::TypeId,
     future::Future,
@@ -85,6 +86,7 @@ pub struct Server {
     notifications: Option<mpsc::UnboundedSender<()>>,
 }
 
+
 pub trait Executor: Send + Clone {
     type Sleep: Send + Future;
     fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
@@ -107,6 +109,23 @@ struct StoreWriteGuard<'a> {
     _not_send: PhantomData<Rc<()>>,
 }
 
+#[derive(Serialize)]
+pub struct ServerSnapshot<'a> {
+    peer: &'a Peer,
+    #[serde(serialize_with = "serialize_deref")]
+    store: RwLockReadGuard<'a, Store>,
+}
+
+pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+    T: Deref<Target = U>,
+    U: Serialize
+{
+    Serialize::serialize(value.deref(), serializer)
+}
+  
+
 impl Server {
     pub fn new(
         app_state: Arc<AppState>,
@@ -1469,6 +1488,13 @@ impl Server {
             _not_send: PhantomData,
         }
     }
+    
+    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
+        ServerSnapshot {
+            store: self.store.read().await,
+            peer: &self.peer
+        }
+    }
 }
 
 impl<'a> Deref for StoreReadGuard<'a> {

crates/collab/src/rpc/store.rs 🔗

@@ -2,18 +2,21 @@ use crate::db::{self, ChannelId, UserId};
 use anyhow::{anyhow, Result};
 use collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
 use rpc::{proto, ConnectionId, Receipt};
+use serde::Serialize;
 use std::{collections::hash_map, mem, path::PathBuf};
 use tracing::instrument;
 
-#[derive(Default)]
+#[derive(Default, Serialize)]
 pub struct Store {
     connections: HashMap<ConnectionId, ConnectionState>,
     connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
     projects: HashMap<u64, Project>,
+    #[serde(skip)]
     channels: HashMap<ChannelId, Channel>,
     next_project_id: u64,
 }
 
+#[derive(Serialize)]
 struct ConnectionState {
     user_id: UserId,
     projects: HashSet<u64>,
@@ -21,21 +24,25 @@ struct ConnectionState {
     channels: HashSet<ChannelId>,
 }
 
+#[derive(Serialize)]
 pub struct Project {
     pub host_connection_id: ConnectionId,
     pub host_user_id: UserId,
     pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
+    #[serde(skip)]
     pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
     pub active_replica_ids: HashSet<ReplicaId>,
     pub worktrees: HashMap<u64, Worktree>,
     pub language_servers: Vec<proto::LanguageServer>,
 }
 
-#[derive(Default)]
+#[derive(Default, Serialize)]
 pub struct Worktree {
     pub root_name: String,
     pub visible: bool,
+    #[serde(skip)]
     pub entries: HashMap<u64, proto::Entry>,
+    #[serde(skip)]
     pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
     pub scan_id: u64,
 }

crates/rpc/src/peer.rs 🔗

@@ -10,6 +10,7 @@ use futures::{
     FutureExt, SinkExt, StreamExt,
 };
 use parking_lot::{Mutex, RwLock};
+use serde::{ser::SerializeStruct, Serialize};
 use smol_timeout::TimeoutExt;
 use std::sync::atomic::Ordering::SeqCst;
 use std::{
@@ -24,7 +25,7 @@ use std::{
 };
 use tracing::instrument;
 
-#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize)]
 pub struct ConnectionId(pub u32);
 
 impl fmt::Display for ConnectionId {
@@ -89,10 +90,12 @@ pub struct Peer {
     next_connection_id: AtomicU32,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Serialize)]
 pub struct ConnectionState {
+    #[serde(skip)]
     outgoing_tx: mpsc::UnboundedSender<proto::Message>,
     next_message_id: Arc<AtomicU32>,
+    #[serde(skip)]
     response_channels:
         Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
 }
@@ -471,6 +474,17 @@ impl Peer {
     }
 }
 
+impl Serialize for Peer {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        let mut state = serializer.serialize_struct("Peer", 2)?;
+        state.serialize_field("connections", &*self.connections.read())?;
+        state.end()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

styles/src/styleTree/workspace.ts 🔗

@@ -8,7 +8,6 @@ export function workspaceBackground(theme: Theme) {
 }
 
 export default function workspace(theme: Theme) {
-
   const tab = {
     height: 32,
     background: workspaceBackground(theme),