@@ -39,6 +39,7 @@ use rpc::{
use serde::{Serialize, Serializer};
use std::{
any::TypeId,
+ fmt,
future::Future,
marker::PhantomData,
net::SocketAddr,
@@ -67,20 +68,63 @@ lazy_static! {
.unwrap();
}
-type MessageHandler = Box<
- dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>,
->;
+type MessageHandler =
+ Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
struct Response<R> {
- server: Arc<Server>,
+ peer: Arc<Peer>,
receipt: Receipt<R>,
responded: Arc<AtomicBool>,
}
+impl<R: RequestMessage> Response<R> {
+ fn send(self, payload: R::Response) -> Result<()> {
+ self.responded.store(true, SeqCst);
+ self.peer.respond(self.receipt, payload)?;
+ Ok(())
+ }
+}
+
+#[derive(Clone)]
struct Session {
user_id: UserId,
connection_id: ConnectionId,
db: Arc<Mutex<DbHandle>>,
+ peer: Arc<Peer>,
+ connection_pool: Arc<Mutex<ConnectionPool>>,
+ live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
+}
+
+impl Session {
+ async fn db(&self) -> MutexGuard<DbHandle> {
+ #[cfg(test)]
+ tokio::task::yield_now().await;
+ let guard = self.db.lock().await;
+ #[cfg(test)]
+ tokio::task::yield_now().await;
+ guard
+ }
+
+ async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
+ #[cfg(test)]
+ tokio::task::yield_now().await;
+ let guard = self.connection_pool.lock().await;
+ #[cfg(test)]
+ tokio::task::yield_now().await;
+ ConnectionPoolGuard {
+ guard,
+ _not_send: PhantomData,
+ }
+ }
+}
+
+impl fmt::Debug for Session {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Session")
+ .field("user_id", &self.user_id)
+ .field("connection_id", &self.connection_id)
+ .finish()
+ }
}
struct DbHandle(Arc<DefaultDb>);
@@ -93,17 +137,9 @@ impl Deref for DbHandle {
}
}
-impl<R: RequestMessage> Response<R> {
- fn send(self, payload: R::Response) -> Result<()> {
- self.responded.store(true, SeqCst);
- self.server.peer.respond(self.receipt, payload)?;
- Ok(())
- }
-}
-
pub struct Server {
peer: Arc<Peer>,
- pub(crate) connection_pool: Mutex<ConnectionPool>,
+ pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
}
@@ -148,76 +184,74 @@ impl Server {
};
server
- .add_request_handler(Server::ping)
- .add_request_handler(Server::create_room)
- .add_request_handler(Server::join_room)
- .add_message_handler(Server::leave_room)
- .add_request_handler(Server::call)
- .add_request_handler(Server::cancel_call)
- .add_message_handler(Server::decline_call)
- .add_request_handler(Server::update_participant_location)
- .add_request_handler(Server::share_project)
- .add_message_handler(Server::unshare_project)
- .add_request_handler(Server::join_project)
- .add_message_handler(Server::leave_project)
- .add_request_handler(Server::update_project)
- .add_request_handler(Server::update_worktree)
- .add_message_handler(Server::start_language_server)
- .add_message_handler(Server::update_language_server)
- .add_request_handler(Server::update_diagnostic_summary)
- .add_request_handler(Server::forward_project_request::<proto::GetHover>)
- .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
- .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
- .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
- .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
- .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
- .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
- .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
- .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
- .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
- .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
- .add_request_handler(
- Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
- )
- .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
- .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
- .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
- .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
- .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
- .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
- .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
- .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
- .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
- .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
- .add_message_handler(Server::create_buffer_for_peer)
- .add_request_handler(Server::update_buffer)
- .add_message_handler(Server::update_buffer_file)
- .add_message_handler(Server::buffer_reloaded)
- .add_message_handler(Server::buffer_saved)
- .add_request_handler(Server::save_buffer)
- .add_request_handler(Server::get_users)
- .add_request_handler(Server::fuzzy_search_users)
- .add_request_handler(Server::request_contact)
- .add_request_handler(Server::remove_contact)
- .add_request_handler(Server::respond_to_contact_request)
- .add_request_handler(Server::follow)
- .add_message_handler(Server::unfollow)
- .add_message_handler(Server::update_followers)
- .add_message_handler(Server::update_diff_base)
- .add_request_handler(Server::get_private_user_info);
+ .add_request_handler(ping)
+ .add_request_handler(create_room)
+ .add_request_handler(join_room)
+ .add_message_handler(leave_room)
+ .add_request_handler(call)
+ .add_request_handler(cancel_call)
+ .add_message_handler(decline_call)
+ .add_request_handler(update_participant_location)
+ .add_request_handler(share_project)
+ .add_message_handler(unshare_project)
+ .add_request_handler(join_project)
+ .add_message_handler(leave_project)
+ .add_request_handler(update_project)
+ .add_request_handler(update_worktree)
+ .add_message_handler(start_language_server)
+ .add_message_handler(update_language_server)
+ .add_request_handler(update_diagnostic_summary)
+ .add_request_handler(forward_project_request::<proto::GetHover>)
+ .add_request_handler(forward_project_request::<proto::GetDefinition>)
+ .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
+ .add_request_handler(forward_project_request::<proto::GetReferences>)
+ .add_request_handler(forward_project_request::<proto::SearchProject>)
+ .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
+ .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
+ .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
+ .add_request_handler(forward_project_request::<proto::OpenBufferById>)
+ .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
+ .add_request_handler(forward_project_request::<proto::GetCompletions>)
+ .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
+ .add_request_handler(forward_project_request::<proto::GetCodeActions>)
+ .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
+ .add_request_handler(forward_project_request::<proto::PrepareRename>)
+ .add_request_handler(forward_project_request::<proto::PerformRename>)
+ .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
+ .add_request_handler(forward_project_request::<proto::FormatBuffers>)
+ .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
+ .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
+ .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
+ .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
+ .add_message_handler(create_buffer_for_peer)
+ .add_request_handler(update_buffer)
+ .add_message_handler(update_buffer_file)
+ .add_message_handler(buffer_reloaded)
+ .add_message_handler(buffer_saved)
+ .add_request_handler(save_buffer)
+ .add_request_handler(get_users)
+ .add_request_handler(fuzzy_search_users)
+ .add_request_handler(request_contact)
+ .add_request_handler(remove_contact)
+ .add_request_handler(respond_to_contact_request)
+ .add_request_handler(follow)
+ .add_message_handler(unfollow)
+ .add_message_handler(update_followers)
+ .add_message_handler(update_diff_base)
+ .add_request_handler(get_private_user_info);
Arc::new(server)
}
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
- F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Session) -> Fut,
+ F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
let prev_handler = self.handlers.insert(
TypeId::of::<M>(),
- Box::new(move |server, envelope, session| {
+ Box::new(move |envelope, session| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let span = info_span!(
"handle message",
@@ -229,7 +263,7 @@ impl Server {
"message received"
);
});
- let future = (handler)(server, *envelope, session);
+ let future = (handler)(*envelope, session);
async move {
if let Err(error) = future.await {
tracing::error!(%error, "error handling message");
@@ -247,34 +281,33 @@ impl Server {
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
- F: 'static + Send + Sync + Fn(Arc<Self>, M, Session) -> Fut,
+ F: 'static + Send + Sync + Fn(M, Session) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
- self.add_handler(move |server, envelope, session| {
- handler(server, envelope.payload, session)
- });
+ self.add_handler(move |envelope, session| handler(envelope.payload, session));
self
}
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
- F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
+ F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
let handler = Arc::new(handler);
- self.add_handler(move |server, envelope, session| {
+ self.add_handler(move |envelope, session| {
let receipt = envelope.receipt();
let handler = handler.clone();
async move {
+ let peer = session.peer.clone();
let responded = Arc::new(AtomicBool::default());
let response = Response {
- server: server.clone(),
+ peer: peer.clone(),
responded: responded.clone(),
receipt,
};
- match (handler)(server.clone(), envelope.payload, response, session).await {
+ match (handler)(envelope.payload, response, session).await {
Ok(()) => {
if responded.load(std::sync::atomic::Ordering::SeqCst) {
Ok(())
@@ -283,7 +316,7 @@ impl Server {
}
}
Err(error) => {
- server.peer.respond_with_error(
+ peer.respond_with_error(
receipt,
proto::Error {
message: error.to_string(),
@@ -304,7 +337,7 @@ impl Server {
mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: E,
) -> impl Future<Output = Result<()>> {
- let mut this = self.clone();
+ let this = self.clone();
let user_id = user.id;
let login = user.github_login;
let span = info_span!("handle connection", %user_id, %login, %address);
@@ -340,7 +373,7 @@ impl Server {
).await?;
{
- let mut pool = this.connection_pool().await;
+ let mut pool = this.connection_pool.lock().await;
pool.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
@@ -356,13 +389,19 @@ impl Server {
this.peer.send(connection_id, incoming_call)?;
}
- this.update_user_contacts(user_id).await?;
+ let session = Session {
+ user_id,
+ connection_id,
+ db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
+ peer: this.peer.clone(),
+ connection_pool: this.connection_pool.clone(),
+ live_kit_client: this.app_state.live_kit_client.clone()
+ };
+ update_user_contacts(user_id, &session).await?;
let handle_io = handle_io.fuse();
futures::pin_mut!(handle_io);
- let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone())));
-
// Handlers for foreground messages are pushed into the following `FuturesUnordered`.
// This prevents deadlocks when e.g., client A performs a request to client B and
// client B performs a request to client A. If both clients stop processing further
@@ -390,12 +429,7 @@ impl Server {
let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
let is_background = message.is_background();
- let session = Session {
- user_id,
- connection_id,
- db: db.clone(),
- };
- let handle_message = (handler)(this.clone(), message, session);
+ let handle_message = (handler)(message, session.clone());
drop(span_enter);
let handle_message = handle_message.instrument(span);
@@ -417,7 +451,7 @@ impl Server {
drop(foreground_message_handlers);
tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
- if let Err(error) = this.sign_out(connection_id, user_id).await {
+ if let Err(error) = sign_out(session).await {
tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
}
@@ -425,40 +459,6 @@ impl Server {
}.instrument(span)
}
- #[instrument(skip(self), err)]
- async fn sign_out(
- self: &mut Arc<Self>,
- connection_id: ConnectionId,
- user_id: UserId,
- ) -> Result<()> {
- self.peer.disconnect(connection_id);
- let decline_calls = {
- let mut pool = self.connection_pool().await;
- pool.remove_connection(connection_id)?;
- let mut connections = pool.user_connection_ids(user_id);
- connections.next().is_none()
- };
-
- self.leave_room_for_connection(connection_id, user_id)
- .await
- .trace_err();
- if decline_calls {
- if let Some(room) = self
- .app_state
- .db
- .decline_call(None, user_id)
- .await
- .trace_err()
- {
- self.room_updated(&room);
- }
- }
-
- self.update_user_contacts(user_id).await?;
-
- Ok(())
- }
-
pub async fn invite_code_redeemed(
self: &Arc<Self>,
inviter_id: UserId,
@@ -466,7 +466,7 @@ impl Server {
) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code {
- let pool = self.connection_pool().await;
+ let pool = self.connection_pool.lock().await;
let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send(
@@ -492,7 +492,7 @@ impl Server {
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
if let Some(invite_code) = &user.invite_code {
- let pool = self.connection_pool().await;
+ let pool = self.connection_pool.lock().await;
for connection_id in pool.user_connection_ids(user_id) {
self.peer.send(
connection_id,
@@ -510,1360 +510,1194 @@ impl Server {
Ok(())
}
- async fn ping(
- self: Arc<Server>,
- _: proto::Ping,
- response: Response<proto::Ping>,
- _session: Session,
- ) -> Result<()> {
- response.send(proto::Ack {})?;
- Ok(())
+ pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
+ ServerSnapshot {
+ connection_pool: ConnectionPoolGuard {
+ guard: self.connection_pool.lock().await,
+ _not_send: PhantomData,
+ },
+ peer: &self.peer,
+ }
}
+}
- async fn create_room(
- self: Arc<Server>,
- _request: proto::CreateRoom,
- response: Response<proto::CreateRoom>,
- session: Session,
- ) -> Result<()> {
- let room = self
- .app_state
- .db
- .create_room(session.user_id, session.connection_id)
- .await?;
-
- let live_kit_connection_info =
- if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
- if let Some(_) = live_kit
- .create_room(room.live_kit_room.clone())
- .await
- .trace_err()
- {
- if let Some(token) = live_kit
- .room_token(&room.live_kit_room, &session.connection_id.to_string())
- .trace_err()
- {
- Some(proto::LiveKitConnectionInfo {
- server_url: live_kit.url().into(),
- token,
- })
- } else {
- None
- }
- } else {
- None
- }
- } else {
- None
- };
+impl<'a> Deref for ConnectionPoolGuard<'a> {
+ type Target = ConnectionPool;
- response.send(proto::CreateRoomResponse {
- room: Some(room),
- live_kit_connection_info,
- })?;
- self.update_user_contacts(session.user_id).await?;
- Ok(())
+ fn deref(&self) -> &Self::Target {
+ &*self.guard
}
+}
- async fn join_room(
- self: Arc<Server>,
- request: proto::JoinRoom,
- response: Response<proto::JoinRoom>,
- session: Session,
- ) -> Result<()> {
- let room = self
- .app_state
- .db
- .join_room(
- RoomId::from_proto(request.id),
- session.user_id,
- session.connection_id,
- )
- .await?;
- for connection_id in self
- .connection_pool()
- .await
- .user_connection_ids(session.user_id)
- {
- self.peer
- .send(connection_id, proto::CallCanceled {})
- .trace_err();
- }
+impl<'a> DerefMut for ConnectionPoolGuard<'a> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut *self.guard
+ }
+}
- let live_kit_connection_info =
- if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
- if let Some(token) = live_kit
- .room_token(&room.live_kit_room, &session.connection_id.to_string())
- .trace_err()
- {
- Some(proto::LiveKitConnectionInfo {
- server_url: live_kit.url().into(),
- token,
- })
- } else {
- None
- }
- } else {
- None
- };
+impl<'a> Drop for ConnectionPoolGuard<'a> {
+ fn drop(&mut self) {
+ #[cfg(test)]
+ self.check_invariants();
+ }
+}
- self.room_updated(&room);
- response.send(proto::JoinRoomResponse {
- room: Some(room),
- live_kit_connection_info,
- })?;
+impl Executor for RealExecutor {
+ type Sleep = Sleep;
- self.update_user_contacts(session.user_id).await?;
- Ok(())
+ fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
+ tokio::task::spawn(future);
}
- async fn leave_room(
- self: Arc<Server>,
- _message: proto::LeaveRoom,
- session: Session,
- ) -> Result<()> {
- self.leave_room_for_connection(session.connection_id, session.user_id)
- .await
+ fn sleep(&self, duration: Duration) -> Self::Sleep {
+ tokio::time::sleep(duration)
}
+}
- async fn leave_room_for_connection(
- self: &Arc<Server>,
- leaving_connection_id: ConnectionId,
- leaving_user_id: UserId,
- ) -> Result<()> {
- let mut contacts_to_update = HashSet::default();
-
- let Some(left_room) = self.app_state.db.leave_room(leaving_connection_id).await? else {
- return Err(anyhow!("no room to leave"))?;
- };
- contacts_to_update.insert(leaving_user_id);
-
- for project in left_room.left_projects.into_values() {
- for connection_id in project.connection_ids {
- if project.host_user_id == leaving_user_id {
- self.peer
- .send(
- connection_id,
- proto::UnshareProject {
- project_id: project.id.to_proto(),
- },
- )
- .trace_err();
- } else {
- self.peer
- .send(
- connection_id,
- proto::RemoveProjectCollaborator {
- project_id: project.id.to_proto(),
- peer_id: leaving_connection_id.0,
- },
- )
- .trace_err();
- }
- }
-
- self.peer
- .send(
- leaving_connection_id,
- proto::UnshareProject {
- project_id: project.id.to_proto(),
- },
- )
- .trace_err();
+fn broadcast<F>(
+ sender_id: ConnectionId,
+ receiver_ids: impl IntoIterator<Item = ConnectionId>,
+ mut f: F,
+) where
+ F: FnMut(ConnectionId) -> anyhow::Result<()>,
+{
+ for receiver_id in receiver_ids {
+ if receiver_id != sender_id {
+ f(receiver_id).trace_err();
}
+ }
+}
- self.room_updated(&left_room.room);
- {
- let pool = self.connection_pool().await;
- for canceled_user_id in left_room.canceled_calls_to_user_ids {
- for connection_id in pool.user_connection_ids(canceled_user_id) {
- self.peer
- .send(connection_id, proto::CallCanceled {})
- .trace_err();
- }
- contacts_to_update.insert(canceled_user_id);
- }
- }
+lazy_static! {
+ static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
+}
- for contact_user_id in contacts_to_update {
- self.update_user_contacts(contact_user_id).await?;
- }
+pub struct ProtocolVersion(u32);
- if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
- live_kit
- .remove_participant(
- left_room.room.live_kit_room.clone(),
- leaving_connection_id.to_string(),
- )
- .await
- .trace_err();
+impl Header for ProtocolVersion {
+ fn name() -> &'static HeaderName {
+ &ZED_PROTOCOL_VERSION
+ }
- if left_room.room.participants.is_empty() {
- live_kit
- .delete_room(left_room.room.live_kit_room)
- .await
- .trace_err();
- }
- }
+ fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
+ where
+ Self: Sized,
+ I: Iterator<Item = &'i axum::http::HeaderValue>,
+ {
+ let version = values
+ .next()
+ .ok_or_else(axum::headers::Error::invalid)?
+ .to_str()
+ .map_err(|_| axum::headers::Error::invalid())?
+ .parse()
+ .map_err(|_| axum::headers::Error::invalid())?;
+ Ok(Self(version))
+ }
- Ok(())
+ fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
+ values.extend([self.0.to_string().parse().unwrap()]);
}
+}
- async fn call(
- self: Arc<Server>,
- request: proto::Call,
- response: Response<proto::Call>,
- session: Session,
- ) -> Result<()> {
- let room_id = RoomId::from_proto(request.room_id);
- let calling_user_id = session.user_id;
- let calling_connection_id = session.connection_id;
- let called_user_id = UserId::from_proto(request.called_user_id);
- let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
- if !self
- .app_state
- .db
- .has_contact(calling_user_id, called_user_id)
- .await?
- {
- return Err(anyhow!("cannot call a user who isn't a contact"))?;
+pub fn routes(server: Arc<Server>) -> Router<Body> {
+ Router::new()
+ .route("/rpc", get(handle_websocket_request))
+ .layer(
+ ServiceBuilder::new()
+ .layer(Extension(server.app_state.clone()))
+ .layer(middleware::from_fn(auth::validate_header)),
+ )
+ .route("/metrics", get(handle_metrics))
+ .layer(Extension(server))
+}
+
+pub async fn handle_websocket_request(
+ TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
+ ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
+ Extension(server): Extension<Arc<Server>>,
+ Extension(user): Extension<User>,
+ ws: WebSocketUpgrade,
+) -> axum::response::Response {
+ if protocol_version != rpc::PROTOCOL_VERSION {
+ return (
+ StatusCode::UPGRADE_REQUIRED,
+ "client must be upgraded".to_string(),
+ )
+ .into_response();
+ }
+ let socket_address = socket_address.to_string();
+ ws.on_upgrade(move |socket| {
+ use util::ResultExt;
+ let socket = socket
+ .map_ok(to_tungstenite_message)
+ .err_into()
+ .with(|message| async move { Ok(to_axum_message(message)) });
+ let connection = Connection::new(Box::pin(socket));
+ async move {
+ server
+ .handle_connection(connection, socket_address, user, None, RealExecutor)
+ .await
+ .log_err();
}
+ })
+}
- let (room, incoming_call) = self
- .app_state
- .db
- .call(
- room_id,
- calling_user_id,
- calling_connection_id,
- called_user_id,
- initial_project_id,
- )
- .await?;
- self.room_updated(&room);
- self.update_user_contacts(called_user_id).await?;
+pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
+ let connections = server
+ .connection_pool
+ .lock()
+ .await
+ .connections()
+ .filter(|connection| !connection.admin)
+ .count();
- let mut calls = self
- .connection_pool()
- .await
- .user_connection_ids(called_user_id)
- .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
- .collect::<FuturesUnordered<_>>();
-
- while let Some(call_response) = calls.next().await {
- match call_response.as_ref() {
- Ok(_) => {
- response.send(proto::Ack {})?;
- return Ok(());
- }
- Err(_) => {
- call_response.trace_err();
- }
- }
- }
+ METRIC_CONNECTIONS.set(connections as _);
- let room = self
- .app_state
- .db
- .call_failed(room_id, called_user_id)
- .await?;
- self.room_updated(&room);
- self.update_user_contacts(called_user_id).await?;
+ let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
+ METRIC_SHARED_PROJECTS.set(shared_projects as _);
- Err(anyhow!("failed to ring user"))?
- }
+ let encoder = prometheus::TextEncoder::new();
+ let metric_families = prometheus::gather();
+ let encoded_metrics = encoder
+ .encode_to_string(&metric_families)
+ .map_err(|err| anyhow!("{}", err))?;
+ Ok(encoded_metrics)
+}
- async fn cancel_call(
- self: Arc<Server>,
- request: proto::CancelCall,
- response: Response<proto::CancelCall>,
- session: Session,
- ) -> Result<()> {
- let called_user_id = UserId::from_proto(request.called_user_id);
- let room_id = RoomId::from_proto(request.room_id);
- let room = self
- .app_state
- .db
- .cancel_call(Some(room_id), session.connection_id, called_user_id)
- .await?;
- for connection_id in self
- .connection_pool()
+#[instrument(err)]
+async fn sign_out(session: Session) -> Result<()> {
+ session.peer.disconnect(session.connection_id);
+ let decline_calls = {
+ let mut pool = session.connection_pool().await;
+ pool.remove_connection(session.connection_id)?;
+ let mut connections = pool.user_connection_ids(session.user_id);
+ connections.next().is_none()
+ };
+
+ leave_room_for_session(&session).await.trace_err();
+ if decline_calls {
+ if let Some(room) = session
+ .db()
.await
- .user_connection_ids(called_user_id)
+ .decline_call(None, session.user_id)
+ .await
+ .trace_err()
{
- self.peer
- .send(connection_id, proto::CallCanceled {})
- .trace_err();
+ room_updated(&room, &session);
}
- self.room_updated(&room);
- response.send(proto::Ack {})?;
-
- self.update_user_contacts(called_user_id).await?;
- Ok(())
}
- async fn decline_call(
- self: Arc<Server>,
- message: proto::DeclineCall,
- session: Session,
- ) -> Result<()> {
- let room_id = RoomId::from_proto(message.room_id);
- let room = self
- .app_state
- .db
- .decline_call(Some(room_id), session.user_id)
- .await?;
- for connection_id in self
- .connection_pool()
+ update_user_contacts(session.user_id, &session).await?;
+
+ Ok(())
+}
+
+async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
+ response.send(proto::Ack {})?;
+ Ok(())
+}
+
+async fn create_room(
+ _request: proto::CreateRoom,
+ response: Response<proto::CreateRoom>,
+ session: Session,
+) -> Result<()> {
+ let room = session
+ .db()
+ .await
+ .create_room(session.user_id, session.connection_id)
+ .await?;
+
+ let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
+ if let Some(_) = live_kit
+ .create_room(room.live_kit_room.clone())
.await
- .user_connection_ids(session.user_id)
+ .trace_err()
{
- self.peer
- .send(connection_id, proto::CallCanceled {})
- .trace_err();
- }
- self.room_updated(&room);
- self.update_user_contacts(session.user_id).await?;
- Ok(())
- }
-
- async fn update_participant_location(
- self: Arc<Server>,
- request: proto::UpdateParticipantLocation,
- response: Response<proto::UpdateParticipantLocation>,
- session: Session,
- ) -> Result<()> {
- let room_id = RoomId::from_proto(request.room_id);
- let location = request
- .location
- .ok_or_else(|| anyhow!("invalid location"))?;
- let room = self
- .app_state
- .db
- .update_room_participant_location(room_id, session.connection_id, location)
- .await?;
- self.room_updated(&room);
- response.send(proto::Ack {})?;
- Ok(())
- }
-
- fn room_updated(&self, room: &proto::Room) {
- for participant in &room.participants {
- self.peer
- .send(
- ConnectionId(participant.peer_id),
- proto::RoomUpdated {
- room: Some(room.clone()),
- },
- )
- .trace_err();
- }
- }
-
- async fn share_project(
- self: Arc<Server>,
- request: proto::ShareProject,
- response: Response<proto::ShareProject>,
- session: Session,
- ) -> Result<()> {
- let (project_id, room) = self
- .app_state
- .db
- .share_project(
- RoomId::from_proto(request.room_id),
- session.connection_id,
- &request.worktrees,
- )
- .await?;
- response.send(proto::ShareProjectResponse {
- project_id: project_id.to_proto(),
- })?;
- self.room_updated(&room);
-
- Ok(())
- }
-
- async fn unshare_project(
- self: Arc<Server>,
- message: proto::UnshareProject,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(message.project_id);
-
- let (room, guest_connection_ids) = self
- .app_state
- .db
- .unshare_project(project_id, session.connection_id)
- .await?;
-
- broadcast(session.connection_id, guest_connection_ids, |conn_id| {
- self.peer.send(conn_id, message.clone())
- });
- self.room_updated(&room);
-
- Ok(())
- }
-
- async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
- let contacts = self.app_state.db.get_contacts(user_id).await?;
- let busy = self.app_state.db.is_user_busy(user_id).await?;
- let pool = self.connection_pool().await;
- let updated_contact = contact_for_user(user_id, false, busy, &pool);
- for contact in contacts {
- if let db::Contact::Accepted {
- user_id: contact_user_id,
- ..
- } = contact
+ if let Some(token) = live_kit
+ .room_token(&room.live_kit_room, &session.connection_id.to_string())
+ .trace_err()
{
- for contact_conn_id in pool.user_connection_ids(contact_user_id) {
- self.peer
- .send(
- contact_conn_id,
- proto::UpdateContacts {
- contacts: vec![updated_contact.clone()],
- remove_contacts: Default::default(),
- incoming_requests: Default::default(),
- remove_incoming_requests: Default::default(),
- outgoing_requests: Default::default(),
- remove_outgoing_requests: Default::default(),
- },
- )
- .trace_err();
- }
- }
- }
- Ok(())
- }
-
- async fn join_project(
- self: Arc<Server>,
- request: proto::JoinProject,
- response: Response<proto::JoinProject>,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let guest_user_id = session.user_id;
-
- tracing::info!(%project_id, "join project");
-
- let (project, replica_id) = self
- .app_state
- .db
- .join_project(project_id, session.connection_id)
- .await?;
-
- let collaborators = project
- .collaborators
- .iter()
- .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
- .map(|collaborator| proto::Collaborator {
- peer_id: collaborator.connection_id as u32,
- replica_id: collaborator.replica_id.0 as u32,
- user_id: collaborator.user_id.to_proto(),
- })
- .collect::<Vec<_>>();
- let worktrees = project
- .worktrees
- .iter()
- .map(|(id, worktree)| proto::WorktreeMetadata {
- id: id.to_proto(),
- root_name: worktree.root_name.clone(),
- visible: worktree.visible,
- abs_path: worktree.abs_path.clone(),
- })
- .collect::<Vec<_>>();
-
- for collaborator in &collaborators {
- self.peer
- .send(
- ConnectionId(collaborator.peer_id),
- proto::AddProjectCollaborator {
- project_id: project_id.to_proto(),
- collaborator: Some(proto::Collaborator {
- peer_id: session.connection_id.0,
- replica_id: replica_id.0 as u32,
- user_id: guest_user_id.to_proto(),
- }),
- },
- )
- .trace_err();
- }
-
- // First, we send the metadata associated with each worktree.
- response.send(proto::JoinProjectResponse {
- worktrees: worktrees.clone(),
- replica_id: replica_id.0 as u32,
- collaborators: collaborators.clone(),
- language_servers: project.language_servers.clone(),
- })?;
-
- for (worktree_id, worktree) in project.worktrees {
- #[cfg(any(test, feature = "test-support"))]
- const MAX_CHUNK_SIZE: usize = 2;
- #[cfg(not(any(test, feature = "test-support")))]
- const MAX_CHUNK_SIZE: usize = 256;
-
- // Stream this worktree's entries.
- let message = proto::UpdateWorktree {
- project_id: project_id.to_proto(),
- worktree_id: worktree_id.to_proto(),
- abs_path: worktree.abs_path.clone(),
- root_name: worktree.root_name,
- updated_entries: worktree.entries,
- removed_entries: Default::default(),
- scan_id: worktree.scan_id,
- is_last_update: worktree.is_complete,
- };
- for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
- self.peer.send(session.connection_id, update.clone())?;
- }
-
- // Stream this worktree's diagnostics.
- for summary in worktree.diagnostic_summaries {
- self.peer.send(
- session.connection_id,
- proto::UpdateDiagnosticSummary {
- project_id: project_id.to_proto(),
- worktree_id: worktree.id.to_proto(),
- summary: Some(summary),
- },
- )?;
+ Some(proto::LiveKitConnectionInfo {
+ server_url: live_kit.url().into(),
+ token,
+ })
+ } else {
+ None
}
+ } else {
+ None
}
+ } else {
+ None
+ };
+
+ response.send(proto::CreateRoomResponse {
+ room: Some(room),
+ live_kit_connection_info,
+ })?;
+ update_user_contacts(session.user_id, &session).await?;
+ Ok(())
+}
- for language_server in &project.language_servers {
- self.peer.send(
- session.connection_id,
- proto::UpdateLanguageServer {
- project_id: project_id.to_proto(),
- language_server_id: language_server.id,
- variant: Some(
- proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
- proto::LspDiskBasedDiagnosticsUpdated {},
- ),
- ),
- },
- )?;
- }
-
- Ok(())
+async fn join_room(
+ request: proto::JoinRoom,
+ response: Response<proto::JoinRoom>,
+ session: Session,
+) -> Result<()> {
+ let room = session
+ .db()
+ .await
+ .join_room(
+ RoomId::from_proto(request.id),
+ session.user_id,
+ session.connection_id,
+ )
+ .await?;
+ for connection_id in session
+ .connection_pool()
+ .await
+ .user_connection_ids(session.user_id)
+ {
+ session
+ .peer
+ .send(connection_id, proto::CallCanceled {})
+ .trace_err();
}
- async fn leave_project(
- self: Arc<Server>,
- request: proto::LeaveProject,
- session: Session,
- ) -> Result<()> {
- let sender_id = session.connection_id;
- let project_id = ProjectId::from_proto(request.project_id);
- let project;
+ let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
+ if let Some(token) = live_kit
+ .room_token(&room.live_kit_room, &session.connection_id.to_string())
+ .trace_err()
{
- project = self
- .app_state
- .db
- .leave_project(project_id, sender_id)
- .await?;
- tracing::info!(
- %project_id,
- host_user_id = %project.host_user_id,
- host_connection_id = %project.host_connection_id,
- "leave project"
- );
-
- broadcast(sender_id, project.connection_ids, |conn_id| {
- self.peer.send(
- conn_id,
- proto::RemoveProjectCollaborator {
- project_id: project_id.to_proto(),
- peer_id: sender_id.0,
- },
- )
- });
+ Some(proto::LiveKitConnectionInfo {
+ server_url: live_kit.url().into(),
+ token,
+ })
+ } else {
+ None
}
+ } else {
+ None
+ };
+
+ room_updated(&room, &session);
+ response.send(proto::JoinRoomResponse {
+ room: Some(room),
+ live_kit_connection_info,
+ })?;
+
+ update_user_contacts(session.user_id, &session).await?;
+ Ok(())
+}
- Ok(())
- }
-
- async fn update_project(
- self: Arc<Server>,
- request: proto::UpdateProject,
- response: Response<proto::UpdateProject>,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let (room, guest_connection_ids) = self
- .app_state
- .db
- .update_project(project_id, session.connection_id, &request.worktrees)
- .await?;
- broadcast(
- session.connection_id,
- guest_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- self.room_updated(&room);
- response.send(proto::Ack {})?;
+async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
+ leave_room_for_session(&session).await
+}
- Ok(())
+async fn call(
+ request: proto::Call,
+ response: Response<proto::Call>,
+ session: Session,
+) -> Result<()> {
+ let room_id = RoomId::from_proto(request.room_id);
+ let calling_user_id = session.user_id;
+ let calling_connection_id = session.connection_id;
+ let called_user_id = UserId::from_proto(request.called_user_id);
+ let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
+ if !session
+ .db()
+ .await
+ .has_contact(calling_user_id, called_user_id)
+ .await?
+ {
+ return Err(anyhow!("cannot call a user who isn't a contact"))?;
}
- async fn update_worktree(
- self: Arc<Server>,
- request: proto::UpdateWorktree,
- response: Response<proto::UpdateWorktree>,
- session: Session,
- ) -> Result<()> {
- let guest_connection_ids = self
- .app_state
- .db
- .update_worktree(&request, session.connection_id)
- .await?;
+ let (room, incoming_call) = session
+ .db()
+ .await
+ .call(
+ room_id,
+ calling_user_id,
+ calling_connection_id,
+ called_user_id,
+ initial_project_id,
+ )
+ .await?;
+ room_updated(&room, &session);
+ update_user_contacts(called_user_id, &session).await?;
- broadcast(
- session.connection_id,
- guest_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- response.send(proto::Ack {})?;
- Ok(())
+ let mut calls = session
+ .connection_pool()
+ .await
+ .user_connection_ids(called_user_id)
+ .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
+ .collect::<FuturesUnordered<_>>();
+
+ while let Some(call_response) = calls.next().await {
+ match call_response.as_ref() {
+ Ok(_) => {
+ response.send(proto::Ack {})?;
+ return Ok(());
+ }
+ Err(_) => {
+ call_response.trace_err();
+ }
+ }
}
- async fn update_diagnostic_summary(
- self: Arc<Server>,
- request: proto::UpdateDiagnosticSummary,
- response: Response<proto::UpdateDiagnosticSummary>,
- session: Session,
- ) -> Result<()> {
- let guest_connection_ids = self
- .app_state
- .db
- .update_diagnostic_summary(&request, session.connection_id)
- .await?;
-
- broadcast(
- session.connection_id,
- guest_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
-
- response.send(proto::Ack {})?;
- Ok(())
- }
+ let room = session
+ .db()
+ .await
+ .call_failed(room_id, called_user_id)
+ .await?;
+ room_updated(&room, &session);
+ update_user_contacts(called_user_id, &session).await?;
- async fn start_language_server(
- self: Arc<Server>,
- request: proto::StartLanguageServer,
- session: Session,
- ) -> Result<()> {
- let guest_connection_ids = self
- .app_state
- .db
- .start_language_server(&request, session.connection_id)
- .await?;
+ Err(anyhow!("failed to ring user"))?
+}
- broadcast(
- session.connection_id,
- guest_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
+async fn cancel_call(
+ request: proto::CancelCall,
+ response: Response<proto::CancelCall>,
+ session: Session,
+) -> Result<()> {
+ let called_user_id = UserId::from_proto(request.called_user_id);
+ let room_id = RoomId::from_proto(request.room_id);
+ let room = session
+ .db()
+ .await
+ .cancel_call(Some(room_id), session.connection_id, called_user_id)
+ .await?;
+ for connection_id in session
+ .connection_pool()
+ .await
+ .user_connection_ids(called_user_id)
+ {
+ session
+ .peer
+ .send(connection_id, proto::CallCanceled {})
+ .trace_err();
}
+ room_updated(&room, &session);
+ response.send(proto::Ack {})?;
- async fn update_language_server(
- self: Arc<Server>,
- request: proto::UpdateLanguageServer,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
- }
+ update_user_contacts(called_user_id, &session).await?;
+ Ok(())
+}
- async fn forward_project_request<T>(
- self: Arc<Server>,
- request: T,
- response: Response<T>,
- session: Session,
- ) -> Result<()>
- where
- T: EntityMessage + RequestMessage,
+async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
+ let room_id = RoomId::from_proto(message.room_id);
+ let room = session
+ .db()
+ .await
+ .decline_call(Some(room_id), session.user_id)
+ .await?;
+ for connection_id in session
+ .connection_pool()
+ .await
+ .user_connection_ids(session.user_id)
{
- let project_id = ProjectId::from_proto(request.remote_entity_id());
- let collaborators = self
- .app_state
- .db
- .project_collaborators(project_id, session.connection_id)
- .await?;
- let host = collaborators
- .iter()
- .find(|collaborator| collaborator.is_host)
- .ok_or_else(|| anyhow!("host not found"))?;
-
- let payload = self
+ session
.peer
- .forward_request(
- session.connection_id,
- ConnectionId(host.connection_id as u32),
- request,
- )
- .await?;
-
- response.send(payload)?;
- Ok(())
+ .send(connection_id, proto::CallCanceled {})
+ .trace_err();
}
+ room_updated(&room, &session);
+ update_user_contacts(session.user_id, &session).await?;
+ Ok(())
+}
- async fn save_buffer(
- self: Arc<Server>,
- request: proto::SaveBuffer,
- response: Response<proto::SaveBuffer>,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let collaborators = self
- .app_state
- .db
- .project_collaborators(project_id, session.connection_id)
- .await?;
- let host = collaborators
- .into_iter()
- .find(|collaborator| collaborator.is_host)
- .ok_or_else(|| anyhow!("host not found"))?;
- let host_connection_id = ConnectionId(host.connection_id as u32);
- let response_payload = self
- .peer
- .forward_request(session.connection_id, host_connection_id, request.clone())
- .await?;
-
- let mut collaborators = self
- .app_state
- .db
- .project_collaborators(project_id, session.connection_id)
- .await?;
- collaborators
- .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
- let project_connection_ids = collaborators
- .into_iter()
- .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
- broadcast(host_connection_id, project_connection_ids, |conn_id| {
- self.peer
- .forward_send(host_connection_id, conn_id, response_payload.clone())
- });
- response.send(response_payload)?;
- Ok(())
- }
+async fn update_participant_location(
+ request: proto::UpdateParticipantLocation,
+ response: Response<proto::UpdateParticipantLocation>,
+ session: Session,
+) -> Result<()> {
+ let room_id = RoomId::from_proto(request.room_id);
+ let location = request
+ .location
+ .ok_or_else(|| anyhow!("invalid location"))?;
+ let room = session
+ .db()
+ .await
+ .update_room_participant_location(room_id, session.connection_id, location)
+ .await?;
+ room_updated(&room, &session);
+ response.send(proto::Ack {})?;
+ Ok(())
+}
- async fn create_buffer_for_peer(
- self: Arc<Server>,
- request: proto::CreateBufferForPeer,
- session: Session,
- ) -> Result<()> {
- self.peer.forward_send(
+async fn share_project(
+ request: proto::ShareProject,
+ response: Response<proto::ShareProject>,
+ session: Session,
+) -> Result<()> {
+ let (project_id, room) = session
+ .db()
+ .await
+ .share_project(
+ RoomId::from_proto(request.room_id),
session.connection_id,
- ConnectionId(request.peer_id),
- request,
- )?;
- Ok(())
- }
+ &request.worktrees,
+ )
+ .await?;
+ response.send(proto::ShareProjectResponse {
+ project_id: project_id.to_proto(),
+ })?;
+ room_updated(&room, &session);
- async fn update_buffer(
- self: Arc<Server>,
- request: proto::UpdateBuffer,
- response: Response<proto::UpdateBuffer>,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
+ Ok(())
+}
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- response.send(proto::Ack {})?;
- Ok(())
- }
+async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(message.project_id);
- async fn update_buffer_file(
- self: Arc<Server>,
- request: proto::UpdateBufferFile,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
+ let (room, guest_connection_ids) = session
+ .db()
+ .await
+ .unshare_project(project_id, session.connection_id)
+ .await?;
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
- }
+ broadcast(session.connection_id, guest_connection_ids, |conn_id| {
+ session.peer.send(conn_id, message.clone())
+ });
+ room_updated(&room, &session);
- async fn buffer_reloaded(
- self: Arc<Server>,
- request: proto::BufferReloaded,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
- }
+ Ok(())
+}
- async fn buffer_saved(
- self: Arc<Server>,
- request: proto::BufferSaved,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
- }
+async fn join_project(
+ request: proto::JoinProject,
+ response: Response<proto::JoinProject>,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let guest_user_id = session.user_id;
- async fn follow(
- self: Arc<Self>,
- request: proto::Follow,
- response: Response<proto::Follow>,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let leader_id = ConnectionId(request.leader_id);
- let follower_id = session.connection_id;
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
+ tracing::info!(%project_id, "join project");
- if !project_connection_ids.contains(&leader_id) {
- Err(anyhow!("no such peer"))?;
- }
+ let (project, replica_id) = session
+ .db()
+ .await
+ .join_project(project_id, session.connection_id)
+ .await?;
+
+ let collaborators = project
+ .collaborators
+ .iter()
+ .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
+ .map(|collaborator| proto::Collaborator {
+ peer_id: collaborator.connection_id as u32,
+ replica_id: collaborator.replica_id.0 as u32,
+ user_id: collaborator.user_id.to_proto(),
+ })
+ .collect::<Vec<_>>();
+ let worktrees = project
+ .worktrees
+ .iter()
+ .map(|(id, worktree)| proto::WorktreeMetadata {
+ id: id.to_proto(),
+ root_name: worktree.root_name.clone(),
+ visible: worktree.visible,
+ abs_path: worktree.abs_path.clone(),
+ })
+ .collect::<Vec<_>>();
- let mut response_payload = self
+ for collaborator in &collaborators {
+ session
.peer
- .forward_request(session.connection_id, leader_id, request)
- .await?;
- response_payload
- .views
- .retain(|view| view.leader_id != Some(follower_id.0));
- response.send(response_payload)?;
- Ok(())
+ .send(
+ ConnectionId(collaborator.peer_id),
+ proto::AddProjectCollaborator {
+ project_id: project_id.to_proto(),
+ collaborator: Some(proto::Collaborator {
+ peer_id: session.connection_id.0,
+ replica_id: replica_id.0 as u32,
+ user_id: guest_user_id.to_proto(),
+ }),
+ },
+ )
+ .trace_err();
}
- async fn unfollow(self: Arc<Self>, request: proto::Unfollow, session: Session) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let leader_id = ConnectionId(request.leader_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
- if !project_connection_ids.contains(&leader_id) {
- Err(anyhow!("no such peer"))?;
- }
- self.peer
- .forward_send(session.connection_id, leader_id, request)?;
- Ok(())
- }
+ // First, we send the metadata associated with each worktree.
+ response.send(proto::JoinProjectResponse {
+ worktrees: worktrees.clone(),
+ replica_id: replica_id.0 as u32,
+ collaborators: collaborators.clone(),
+ language_servers: project.language_servers.clone(),
+ })?;
- async fn update_followers(
- self: Arc<Self>,
- request: proto::UpdateFollowers,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = session
- .db
- .lock()
- .await
- .project_connection_ids(project_id, session.connection_id)
- .await?;
+ for (worktree_id, worktree) in project.worktrees {
+ #[cfg(any(test, feature = "test-support"))]
+ const MAX_CHUNK_SIZE: usize = 2;
+ #[cfg(not(any(test, feature = "test-support")))]
+ const MAX_CHUNK_SIZE: usize = 256;
- let leader_id = request.variant.as_ref().and_then(|variant| match variant {
- proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
- proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
- proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
- });
- for follower_id in &request.follower_ids {
- let follower_id = ConnectionId(*follower_id);
- if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
- self.peer
- .forward_send(session.connection_id, follower_id, request.clone())?;
- }
+ // Stream this worktree's entries.
+ let message = proto::UpdateWorktree {
+ project_id: project_id.to_proto(),
+ worktree_id: worktree_id.to_proto(),
+ abs_path: worktree.abs_path.clone(),
+ root_name: worktree.root_name,
+ updated_entries: worktree.entries,
+ removed_entries: Default::default(),
+ scan_id: worktree.scan_id,
+ is_last_update: worktree.is_complete,
+ };
+ for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
+ session.peer.send(session.connection_id, update.clone())?;
}
- Ok(())
- }
- async fn get_users(
- self: Arc<Server>,
- request: proto::GetUsers,
- response: Response<proto::GetUsers>,
- _session: Session,
- ) -> Result<()> {
- let user_ids = request
- .user_ids
- .into_iter()
- .map(UserId::from_proto)
- .collect();
- let users = self
- .app_state
- .db
- .get_users_by_ids(user_ids)
- .await?
- .into_iter()
- .map(|user| proto::User {
- id: user.id.to_proto(),
- avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
- github_login: user.github_login,
- })
- .collect();
- response.send(proto::UsersResponse { users })?;
- Ok(())
+ // Stream this worktree's diagnostics.
+ for summary in worktree.diagnostic_summaries {
+ session.peer.send(
+ session.connection_id,
+ proto::UpdateDiagnosticSummary {
+ project_id: project_id.to_proto(),
+ worktree_id: worktree.id.to_proto(),
+ summary: Some(summary),
+ },
+ )?;
+ }
}
- async fn fuzzy_search_users(
- self: Arc<Server>,
- request: proto::FuzzySearchUsers,
- response: Response<proto::FuzzySearchUsers>,
- session: Session,
- ) -> Result<()> {
- let query = request.query;
- let db = &self.app_state.db;
- let users = match query.len() {
- 0 => vec![],
- 1 | 2 => db
- .get_user_by_github_account(&query, None)
- .await?
- .into_iter()
- .collect(),
- _ => db.fuzzy_search_users(&query, 10).await?,
- };
- let users = users
- .into_iter()
- .filter(|user| user.id != session.user_id)
- .map(|user| proto::User {
- id: user.id.to_proto(),
- avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
- github_login: user.github_login,
- })
- .collect();
- response.send(proto::UsersResponse { users })?;
- Ok(())
+ for language_server in &project.language_servers {
+ session.peer.send(
+ session.connection_id,
+ proto::UpdateLanguageServer {
+ project_id: project_id.to_proto(),
+ language_server_id: language_server.id,
+ variant: Some(
+ proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
+ proto::LspDiskBasedDiagnosticsUpdated {},
+ ),
+ ),
+ },
+ )?;
}
- async fn request_contact(
- self: Arc<Server>,
- request: proto::RequestContact,
- response: Response<proto::RequestContact>,
- session: Session,
- ) -> Result<()> {
- let requester_id = session.user_id;
- let responder_id = UserId::from_proto(request.responder_id);
- if requester_id == responder_id {
- return Err(anyhow!("cannot add yourself as a contact"))?;
- }
-
- self.app_state
- .db
- .send_contact_request(requester_id, responder_id)
- .await?;
-
- // Update outgoing contact requests of requester
- let mut update = proto::UpdateContacts::default();
- update.outgoing_requests.push(responder_id.to_proto());
- for connection_id in self
- .connection_pool()
- .await
- .user_connection_ids(requester_id)
- {
- self.peer.send(connection_id, update.clone())?;
- }
+ Ok(())
+}
- // Update incoming contact requests of responder
- let mut update = proto::UpdateContacts::default();
- update
- .incoming_requests
- .push(proto::IncomingContactRequest {
- requester_id: requester_id.to_proto(),
- should_notify: true,
- });
- for connection_id in self
- .connection_pool()
+async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
+ let sender_id = session.connection_id;
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project;
+ {
+ project = session
+ .db()
.await
- .user_connection_ids(responder_id)
- {
- self.peer.send(connection_id, update.clone())?;
- }
+ .leave_project(project_id, sender_id)
+ .await?;
+ tracing::info!(
+ %project_id,
+ host_user_id = %project.host_user_id,
+ host_connection_id = %project.host_connection_id,
+ "leave project"
+ );
- response.send(proto::Ack {})?;
- Ok(())
+ broadcast(sender_id, project.connection_ids, |conn_id| {
+ session.peer.send(
+ conn_id,
+ proto::RemoveProjectCollaborator {
+ project_id: project_id.to_proto(),
+ peer_id: sender_id.0,
+ },
+ )
+ });
}
- async fn respond_to_contact_request(
- self: Arc<Server>,
- request: proto::RespondToContactRequest,
- response: Response<proto::RespondToContactRequest>,
- session: Session,
- ) -> Result<()> {
- let responder_id = session.user_id;
- let requester_id = UserId::from_proto(request.requester_id);
- if request.response == proto::ContactRequestResponse::Dismiss as i32 {
- self.app_state
- .db
- .dismiss_contact_notification(responder_id, requester_id)
- .await?;
- } else {
- let accept = request.response == proto::ContactRequestResponse::Accept as i32;
- self.app_state
- .db
- .respond_to_contact_request(responder_id, requester_id, accept)
- .await?;
- let busy = self.app_state.db.is_user_busy(requester_id).await?;
-
- let pool = self.connection_pool().await;
- // Update responder with new contact
- let mut update = proto::UpdateContacts::default();
- if accept {
- update
- .contacts
- .push(contact_for_user(requester_id, false, busy, &pool));
- }
- update
- .remove_incoming_requests
- .push(requester_id.to_proto());
- for connection_id in pool.user_connection_ids(responder_id) {
- self.peer.send(connection_id, update.clone())?;
- }
+ Ok(())
+}
- // Update requester with new contact
- let mut update = proto::UpdateContacts::default();
- if accept {
- update
- .contacts
- .push(contact_for_user(responder_id, true, busy, &pool));
- }
- update
- .remove_outgoing_requests
- .push(responder_id.to_proto());
- for connection_id in pool.user_connection_ids(requester_id) {
- self.peer.send(connection_id, update.clone())?;
- }
- }
+async fn update_project(
+ request: proto::UpdateProject,
+ response: Response<proto::UpdateProject>,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let (room, guest_connection_ids) = session
+ .db()
+ .await
+ .update_project(project_id, session.connection_id, &request.worktrees)
+ .await?;
+ broadcast(
+ session.connection_id,
+ guest_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ room_updated(&room, &session);
+ response.send(proto::Ack {})?;
- response.send(proto::Ack {})?;
- Ok(())
- }
+ Ok(())
+}
- async fn remove_contact(
- self: Arc<Server>,
- request: proto::RemoveContact,
- response: Response<proto::RemoveContact>,
- session: Session,
- ) -> Result<()> {
- let requester_id = session.user_id;
- let responder_id = UserId::from_proto(request.user_id);
- self.app_state
- .db
- .remove_contact(requester_id, responder_id)
- .await?;
+async fn update_worktree(
+ request: proto::UpdateWorktree,
+ response: Response<proto::UpdateWorktree>,
+ session: Session,
+) -> Result<()> {
+ let guest_connection_ids = session
+ .db()
+ .await
+ .update_worktree(&request, session.connection_id)
+ .await?;
+
+ broadcast(
+ session.connection_id,
+ guest_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ response.send(proto::Ack {})?;
+ Ok(())
+}
- // Update outgoing contact requests of requester
- let mut update = proto::UpdateContacts::default();
- update
- .remove_outgoing_requests
- .push(responder_id.to_proto());
- for connection_id in self
- .connection_pool()
- .await
- .user_connection_ids(requester_id)
- {
- self.peer.send(connection_id, update.clone())?;
- }
+async fn update_diagnostic_summary(
+ request: proto::UpdateDiagnosticSummary,
+ response: Response<proto::UpdateDiagnosticSummary>,
+ session: Session,
+) -> Result<()> {
+ let guest_connection_ids = session
+ .db()
+ .await
+ .update_diagnostic_summary(&request, session.connection_id)
+ .await?;
+
+ broadcast(
+ session.connection_id,
+ guest_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
- // Update incoming contact requests of responder
- let mut update = proto::UpdateContacts::default();
- update
- .remove_incoming_requests
- .push(requester_id.to_proto());
- for connection_id in self
- .connection_pool()
- .await
- .user_connection_ids(responder_id)
- {
- self.peer.send(connection_id, update.clone())?;
- }
+ response.send(proto::Ack {})?;
+ Ok(())
+}
- response.send(proto::Ack {})?;
- Ok(())
- }
+async fn start_language_server(
+ request: proto::StartLanguageServer,
+ session: Session,
+) -> Result<()> {
+ let guest_connection_ids = session
+ .db()
+ .await
+ .start_language_server(&request, session.connection_id)
+ .await?;
+
+ broadcast(
+ session.connection_id,
+ guest_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
+}
- async fn update_diff_base(
- self: Arc<Server>,
- request: proto::UpdateDiffBase,
- session: Session,
- ) -> Result<()> {
- let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
- .db
- .project_connection_ids(project_id, session.connection_id)
- .await?;
- broadcast(
- session.connection_id,
- project_connection_ids,
- |connection_id| {
- self.peer
- .forward_send(session.connection_id, connection_id, request.clone())
- },
- );
- Ok(())
- }
+async fn update_language_server(
+ request: proto::UpdateLanguageServer,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
+}
- async fn get_private_user_info(
- self: Arc<Self>,
- _request: proto::GetPrivateUserInfo,
- response: Response<proto::GetPrivateUserInfo>,
- session: Session,
- ) -> Result<()> {
- let metrics_id = self
- .app_state
- .db
- .get_user_metrics_id(session.user_id)
- .await?;
- let user = self
- .app_state
- .db
- .get_user_by_id(session.user_id)
- .await?
- .ok_or_else(|| anyhow!("user not found"))?;
- response.send(proto::GetPrivateUserInfoResponse {
- metrics_id,
- staff: user.admin,
- })?;
- Ok(())
- }
+async fn forward_project_request<T>(
+ request: T,
+ response: Response<T>,
+ session: Session,
+) -> Result<()>
+where
+ T: EntityMessage + RequestMessage,
+{
+ let project_id = ProjectId::from_proto(request.remote_entity_id());
+ let collaborators = session
+ .db()
+ .await
+ .project_collaborators(project_id, session.connection_id)
+ .await?;
+ let host = collaborators
+ .iter()
+ .find(|collaborator| collaborator.is_host)
+ .ok_or_else(|| anyhow!("host not found"))?;
+
+ let payload = session
+ .peer
+ .forward_request(
+ session.connection_id,
+ ConnectionId(host.connection_id as u32),
+ request,
+ )
+ .await?;
- pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
- #[cfg(test)]
- tokio::task::yield_now().await;
- let guard = self.connection_pool.lock().await;
- #[cfg(test)]
- tokio::task::yield_now().await;
- ConnectionPoolGuard {
- guard,
- _not_send: PhantomData,
- }
- }
+ response.send(payload)?;
+ Ok(())
+}
- pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
- ServerSnapshot {
- connection_pool: self.connection_pool().await,
- peer: &self.peer,
- }
- }
+async fn save_buffer(
+ request: proto::SaveBuffer,
+ response: Response<proto::SaveBuffer>,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let collaborators = session
+ .db()
+ .await
+ .project_collaborators(project_id, session.connection_id)
+ .await?;
+ let host = collaborators
+ .into_iter()
+ .find(|collaborator| collaborator.is_host)
+ .ok_or_else(|| anyhow!("host not found"))?;
+ let host_connection_id = ConnectionId(host.connection_id as u32);
+ let response_payload = session
+ .peer
+ .forward_request(session.connection_id, host_connection_id, request.clone())
+ .await?;
+
+ let mut collaborators = session
+ .db()
+ .await
+ .project_collaborators(project_id, session.connection_id)
+ .await?;
+ collaborators
+ .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
+ let project_connection_ids = collaborators
+ .into_iter()
+ .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
+ broadcast(host_connection_id, project_connection_ids, |conn_id| {
+ session
+ .peer
+ .forward_send(host_connection_id, conn_id, response_payload.clone())
+ });
+ response.send(response_payload)?;
+ Ok(())
}
-impl<'a> Deref for ConnectionPoolGuard<'a> {
- type Target = ConnectionPool;
+async fn create_buffer_for_peer(
+ request: proto::CreateBufferForPeer,
+ session: Session,
+) -> Result<()> {
+ session.peer.forward_send(
+ session.connection_id,
+ ConnectionId(request.peer_id),
+ request,
+ )?;
+ Ok(())
+}
- fn deref(&self) -> &Self::Target {
- &*self.guard
- }
+async fn update_buffer(
+ request: proto::UpdateBuffer,
+ response: Response<proto::UpdateBuffer>,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ response.send(proto::Ack {})?;
+ Ok(())
}
-impl<'a> DerefMut for ConnectionPoolGuard<'a> {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut *self.guard
- }
+async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
}
-impl<'a> Drop for ConnectionPoolGuard<'a> {
- fn drop(&mut self) {
- #[cfg(test)]
- self.check_invariants();
- }
+async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
}
-impl Executor for RealExecutor {
- type Sleep = Sleep;
+async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
+}
- fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
- tokio::task::spawn(future);
- }
+async fn follow(
+ request: proto::Follow,
+ response: Response<proto::Follow>,
+ session: Session,
+) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let leader_id = ConnectionId(request.leader_id);
+ let follower_id = session.connection_id;
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+
+ if !project_connection_ids.contains(&leader_id) {
+ Err(anyhow!("no such peer"))?;
+ }
+
+ let mut response_payload = session
+ .peer
+ .forward_request(session.connection_id, leader_id, request)
+ .await?;
+ response_payload
+ .views
+ .retain(|view| view.leader_id != Some(follower_id.0));
+ response.send(response_payload)?;
+ Ok(())
+}
- fn sleep(&self, duration: Duration) -> Self::Sleep {
- tokio::time::sleep(duration)
- }
+async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let leader_id = ConnectionId(request.leader_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+ if !project_connection_ids.contains(&leader_id) {
+ Err(anyhow!("no such peer"))?;
+ }
+ session
+ .peer
+ .forward_send(session.connection_id, leader_id, request)?;
+ Ok(())
}
-fn broadcast<F>(
- sender_id: ConnectionId,
- receiver_ids: impl IntoIterator<Item = ConnectionId>,
- mut f: F,
-) where
- F: FnMut(ConnectionId) -> anyhow::Result<()>,
-{
- for receiver_id in receiver_ids {
- if receiver_id != sender_id {
- f(receiver_id).trace_err();
+async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db
+ .lock()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+
+ let leader_id = request.variant.as_ref().and_then(|variant| match variant {
+ proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
+ proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
+ proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
+ });
+ for follower_id in &request.follower_ids {
+ let follower_id = ConnectionId(*follower_id);
+ if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
+ session
+ .peer
+ .forward_send(session.connection_id, follower_id, request.clone())?;
}
}
+ Ok(())
}
-lazy_static! {
- static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
+async fn get_users(
+ request: proto::GetUsers,
+ response: Response<proto::GetUsers>,
+ session: Session,
+) -> Result<()> {
+ let user_ids = request
+ .user_ids
+ .into_iter()
+ .map(UserId::from_proto)
+ .collect();
+ let users = session
+ .db()
+ .await
+ .get_users_by_ids(user_ids)
+ .await?
+ .into_iter()
+ .map(|user| proto::User {
+ id: user.id.to_proto(),
+ avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
+ github_login: user.github_login,
+ })
+ .collect();
+ response.send(proto::UsersResponse { users })?;
+ Ok(())
}
-pub struct ProtocolVersion(u32);
+async fn fuzzy_search_users(
+ request: proto::FuzzySearchUsers,
+ response: Response<proto::FuzzySearchUsers>,
+ session: Session,
+) -> Result<()> {
+ let query = request.query;
+ let users = match query.len() {
+ 0 => vec![],
+ 1 | 2 => session
+ .db()
+ .await
+ .get_user_by_github_account(&query, None)
+ .await?
+ .into_iter()
+ .collect(),
+ _ => session.db().await.fuzzy_search_users(&query, 10).await?,
+ };
+ let users = users
+ .into_iter()
+ .filter(|user| user.id != session.user_id)
+ .map(|user| proto::User {
+ id: user.id.to_proto(),
+ avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
+ github_login: user.github_login,
+ })
+ .collect();
+ response.send(proto::UsersResponse { users })?;
+ Ok(())
+}
-impl Header for ProtocolVersion {
- fn name() -> &'static HeaderName {
- &ZED_PROTOCOL_VERSION
+async fn request_contact(
+ request: proto::RequestContact,
+ response: Response<proto::RequestContact>,
+ session: Session,
+) -> Result<()> {
+ let requester_id = session.user_id;
+ let responder_id = UserId::from_proto(request.responder_id);
+ if requester_id == responder_id {
+ return Err(anyhow!("cannot add yourself as a contact"))?;
}
- fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
- where
- Self: Sized,
- I: Iterator<Item = &'i axum::http::HeaderValue>,
+ session
+ .db()
+ .await
+ .send_contact_request(requester_id, responder_id)
+ .await?;
+
+ // Update outgoing contact requests of requester
+ let mut update = proto::UpdateContacts::default();
+ update.outgoing_requests.push(responder_id.to_proto());
+ for connection_id in session
+ .connection_pool()
+ .await
+ .user_connection_ids(requester_id)
{
- let version = values
- .next()
- .ok_or_else(axum::headers::Error::invalid)?
- .to_str()
- .map_err(|_| axum::headers::Error::invalid())?
- .parse()
- .map_err(|_| axum::headers::Error::invalid())?;
- Ok(Self(version))
+ session.peer.send(connection_id, update.clone())?;
}
- fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
- values.extend([self.0.to_string().parse().unwrap()]);
+ // Update incoming contact requests of responder
+ let mut update = proto::UpdateContacts::default();
+ update
+ .incoming_requests
+ .push(proto::IncomingContactRequest {
+ requester_id: requester_id.to_proto(),
+ should_notify: true,
+ });
+ for connection_id in session
+ .connection_pool()
+ .await
+ .user_connection_ids(responder_id)
+ {
+ session.peer.send(connection_id, update.clone())?;
}
-}
-pub fn routes(server: Arc<Server>) -> Router<Body> {
- Router::new()
- .route("/rpc", get(handle_websocket_request))
- .layer(
- ServiceBuilder::new()
- .layer(Extension(server.app_state.clone()))
- .layer(middleware::from_fn(auth::validate_header)),
- )
- .route("/metrics", get(handle_metrics))
- .layer(Extension(server))
+ response.send(proto::Ack {})?;
+ Ok(())
}
-pub async fn handle_websocket_request(
- TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
- ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
- Extension(server): Extension<Arc<Server>>,
- Extension(user): Extension<User>,
- ws: WebSocketUpgrade,
-) -> axum::response::Response {
- if protocol_version != rpc::PROTOCOL_VERSION {
- return (
- StatusCode::UPGRADE_REQUIRED,
- "client must be upgraded".to_string(),
- )
- .into_response();
- }
- let socket_address = socket_address.to_string();
- ws.on_upgrade(move |socket| {
- use util::ResultExt;
- let socket = socket
- .map_ok(to_tungstenite_message)
- .err_into()
- .with(|message| async move { Ok(to_axum_message(message)) });
- let connection = Connection::new(Box::pin(socket));
- async move {
- server
- .handle_connection(connection, socket_address, user, None, RealExecutor)
- .await
- .log_err();
+async fn respond_to_contact_request(
+ request: proto::RespondToContactRequest,
+ response: Response<proto::RespondToContactRequest>,
+ session: Session,
+) -> Result<()> {
+ let responder_id = session.user_id;
+ let requester_id = UserId::from_proto(request.requester_id);
+ let db = session.db().await;
+ if request.response == proto::ContactRequestResponse::Dismiss as i32 {
+ db.dismiss_contact_notification(responder_id, requester_id)
+ .await?;
+ } else {
+ let accept = request.response == proto::ContactRequestResponse::Accept as i32;
+
+ db.respond_to_contact_request(responder_id, requester_id, accept)
+ .await?;
+ let busy = db.is_user_busy(requester_id).await?;
+
+ let pool = session.connection_pool().await;
+ // Update responder with new contact
+ let mut update = proto::UpdateContacts::default();
+ if accept {
+ update
+ .contacts
+ .push(contact_for_user(requester_id, false, busy, &pool));
}
- })
+ update
+ .remove_incoming_requests
+ .push(requester_id.to_proto());
+ for connection_id in pool.user_connection_ids(responder_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
+
+ // Update requester with new contact
+ let mut update = proto::UpdateContacts::default();
+ if accept {
+ update
+ .contacts
+ .push(contact_for_user(responder_id, true, busy, &pool));
+ }
+ update
+ .remove_outgoing_requests
+ .push(responder_id.to_proto());
+ for connection_id in pool.user_connection_ids(requester_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
+ }
+
+ response.send(proto::Ack {})?;
+ Ok(())
}
-pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
- let connections = server
- .connection_pool()
- .await
- .connections()
- .filter(|connection| !connection.admin)
- .count();
+async fn remove_contact(
+ request: proto::RemoveContact,
+ response: Response<proto::RemoveContact>,
+ session: Session,
+) -> Result<()> {
+ let requester_id = session.user_id;
+ let responder_id = UserId::from_proto(request.user_id);
+ let db = session.db().await;
+ db.remove_contact(requester_id, responder_id).await?;
+
+ let pool = session.connection_pool().await;
+ // Update outgoing contact requests of requester
+ let mut update = proto::UpdateContacts::default();
+ update
+ .remove_outgoing_requests
+ .push(responder_id.to_proto());
+ for connection_id in pool.user_connection_ids(requester_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
- METRIC_CONNECTIONS.set(connections as _);
+ // Update incoming contact requests of responder
+ let mut update = proto::UpdateContacts::default();
+ update
+ .remove_incoming_requests
+ .push(requester_id.to_proto());
+ for connection_id in pool.user_connection_ids(responder_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
- let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
- METRIC_SHARED_PROJECTS.set(shared_projects as _);
+ response.send(proto::Ack {})?;
+ Ok(())
+}
- let encoder = prometheus::TextEncoder::new();
- let metric_families = prometheus::gather();
- let encoded_metrics = encoder
- .encode_to_string(&metric_families)
- .map_err(|err| anyhow!("{}", err))?;
- Ok(encoded_metrics)
+async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
+ let project_id = ProjectId::from_proto(request.project_id);
+ let project_connection_ids = session
+ .db()
+ .await
+ .project_connection_ids(project_id, session.connection_id)
+ .await?;
+ broadcast(
+ session.connection_id,
+ project_connection_ids,
+ |connection_id| {
+ session
+ .peer
+ .forward_send(session.connection_id, connection_id, request.clone())
+ },
+ );
+ Ok(())
+}
+
+async fn get_private_user_info(
+ _request: proto::GetPrivateUserInfo,
+ response: Response<proto::GetPrivateUserInfo>,
+ session: Session,
+) -> Result<()> {
+ let metrics_id = session
+ .db()
+ .await
+ .get_user_metrics_id(session.user_id)
+ .await?;
+ let user = session
+ .db()
+ .await
+ .get_user_by_id(session.user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+ response.send(proto::GetPrivateUserInfoResponse {
+ metrics_id,
+ staff: user.admin,
+ })?;
+ Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {