rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    os::unix::prelude::OsStrExt,
  46    rc::Rc,
  47    sync::{
  48        atomic::{AtomicBool, Ordering::SeqCst},
  49        Arc,
  50    },
  51    time::Duration,
  52};
  53pub use store::{Store, Worktree};
  54use tokio::{
  55    sync::{Mutex, MutexGuard},
  56    time::Sleep,
  57};
  58use tower::ServiceBuilder;
  59use tracing::{info_span, instrument, Instrument};
  60
// Prometheus gauges exported by this module. Each gauge is registered lazily the first
// time the static is dereferenced; the `unwrap()` panics if that registration fails.
lazy_static! {
    // Gauge registered under the name "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge registered under the name "shared_projects".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  70
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It is called with the server, the id of the sending user and the boxed envelope, and
/// returns a boxed `'static` future that resolves to `()`.
type MessageHandler = Box<
    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
>;
  74
/// An incoming message in the form handlers receive it: the payload plus who sent it.
struct Message<T> {
    /// Id of the user the message is attributed to.
    sender_user_id: UserId,
    /// Id of the connection the message came from (filled from the envelope's `sender_id`).
    sender_connection_id: ConnectionId,
    /// The message body.
    payload: T,
}
  80
/// Handle given to request handlers for replying to the request they are processing.
struct Response<R> {
    /// Server whose peer is used to deliver the reply.
    server: Arc<Server>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Flag shared with the code that invoked the handler; set once `send` is called.
    responded: Arc<AtomicBool>,
}
  86
  87impl<R: RequestMessage> Response<R> {
  88    fn send(self, payload: R::Response) -> Result<()> {
  89        self.responded.store(true, SeqCst);
  90        self.server.peer.respond(self.receipt, payload)?;
  91        Ok(())
  92    }
  93}
  94
/// The RPC server: owns the peer, the in-memory store and the table of message handlers.
pub struct Server {
    /// Peer used to send messages and responses to connections.
    peer: Arc<Peer>,
    /// In-memory state, guarded by an async mutex.
    pub(crate) store: Mutex<Store>,
    /// Shared application state (the code in this file reads `db`, `config` and
    /// `live_kit_client` from it).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
 101
/// Abstraction over spawning tasks and creating timers, so that connection handling does
/// not depend on a specific async runtime.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` without keeping a handle to it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future for a timer of the given `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 107
/// Unit-struct executor.
// NOTE(review): presumably the production `Executor` implementation; its `impl Executor`
// is not visible in this part of the file — confirm.
#[derive(Clone)]
pub struct RealExecutor;
 110
/// Wrapper around the mutex guard protecting the [`Store`].
pub(crate) struct StoreGuard<'a> {
    /// The underlying lock guard.
    guard: MutexGuard<'a, Store>,
    /// `Rc` is not `Send`, so this marker makes `StoreGuard` itself `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 115
/// Serializable view of the server: a borrowed peer together with a locked store.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// The locked store; serialized through `serialize_deref`, i.e. as the value the
    /// guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 122
 123pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 124where
 125    S: Serializer,
 126    T: Deref<Target = U>,
 127    U: Serialize,
 128{
 129    Serialize::serialize(value.deref(), serializer)
 130}
 131
 132impl Server {
    /// Builds a server with a fresh peer and a default store, then registers one handler
    /// for every message type the server accepts.
    ///
    /// Handlers added with `add_request_handler` answer requests; handlers added with
    /// `add_message_handler` process messages that get no reply.
    ///
    /// # Panics
    /// Panics (inside `add_handler`) if two handlers are registered for the same message
    /// type.
    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
        };

        server
            .add_request_handler(Server::ping)
            .add_request_handler(Server::create_room)
            .add_request_handler(Server::join_room)
            .add_message_handler(Server::leave_room)
            .add_request_handler(Server::call)
            .add_request_handler(Server::cancel_call)
            .add_message_handler(Server::decline_call)
            .add_request_handler(Server::update_participant_location)
            .add_request_handler(Server::share_project)
            .add_message_handler(Server::unshare_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_message_handler(Server::update_project)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_message_handler(Server::update_diagnostic_summary)
            // Requests handled by the generic `forward_project_request` handler.
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            .add_message_handler(Server::create_buffer_for_peer)
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_message_handler(Server::update_followers)
            .add_message_handler(Server::update_diff_base)
            .add_request_handler(Server::get_private_user_info);

        Arc::new(server)
    }
 202
 203    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 204    where
 205        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 206        Fut: 'static + Send + Future<Output = Result<()>>,
 207        M: EnvelopedMessage,
 208    {
 209        let prev_handler = self.handlers.insert(
 210            TypeId::of::<M>(),
 211            Box::new(move |server, sender_user_id, envelope| {
 212                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 213                let span = info_span!(
 214                    "handle message",
 215                    payload_type = envelope.payload_type_name()
 216                );
 217                span.in_scope(|| {
 218                    tracing::info!(
 219                        payload_type = envelope.payload_type_name(),
 220                        "message received"
 221                    );
 222                });
 223                let future = (handler)(server, sender_user_id, *envelope);
 224                async move {
 225                    if let Err(error) = future.await {
 226                        tracing::error!(%error, "error handling message");
 227                    }
 228                }
 229                .instrument(span)
 230                .boxed()
 231            }),
 232        );
 233        if prev_handler.is_some() {
 234            panic!("registered a handler for the same message twice");
 235        }
 236        self
 237    }
 238
 239    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 240    where
 241        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 242        Fut: 'static + Send + Future<Output = Result<()>>,
 243        M: EnvelopedMessage,
 244    {
 245        self.add_handler(move |server, sender_user_id, envelope| {
 246            handler(
 247                server,
 248                Message {
 249                    sender_user_id,
 250                    sender_connection_id: envelope.sender_id,
 251                    payload: envelope.payload,
 252                },
 253            )
 254        });
 255        self
 256    }
 257
 258    /// Handle a request while holding a lock to the store. This is useful when we're registering
 259    /// a connection but we want to respond on the connection before anybody else can send on it.
 260    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 261    where
 262        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 263        Fut: Send + Future<Output = Result<()>>,
 264        M: RequestMessage,
 265    {
 266        let handler = Arc::new(handler);
 267        self.add_handler(move |server, sender_user_id, envelope| {
 268            let receipt = envelope.receipt();
 269            let handler = handler.clone();
 270            async move {
 271                let request = Message {
 272                    sender_user_id,
 273                    sender_connection_id: envelope.sender_id,
 274                    payload: envelope.payload,
 275                };
 276                let responded = Arc::new(AtomicBool::default());
 277                let response = Response {
 278                    server: server.clone(),
 279                    responded: responded.clone(),
 280                    receipt,
 281                };
 282                match (handler)(server.clone(), request, response).await {
 283                    Ok(()) => {
 284                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 285                            Ok(())
 286                        } else {
 287                            Err(anyhow!("handler did not send a response"))?
 288                        }
 289                    }
 290                    Err(error) => {
 291                        server.peer.respond_with_error(
 292                            receipt,
 293                            proto::Error {
 294                                message: error.to_string(),
 295                            },
 296                        )?;
 297                        Err(error)
 298                    }
 299                }
 300            }
 301        })
 302    }
 303
    /// Returns a future that drives a single client connection from registration to
    /// sign-out.
    ///
    /// The future:
    /// 1. registers `connection` with the peer and sends a `Hello` carrying the new
    ///    connection id;
    /// 2. reports the connection id through `send_connection_id`, if one was supplied;
    /// 3. on the user's first connection, sends `ShowContacts` and records the flag in the
    ///    database;
    /// 4. loads the user's contacts and invite code, adds the connection to the store and
    ///    sends the initial contacts / invite info (plus a pending incoming call, if the
    ///    store reports one);
    /// 5. loops, dispatching incoming messages to the registered handlers until the I/O
    ///    future finishes or the incoming stream ends;
    /// 6. signs the connection out.
    ///
    /// # Errors
    /// Errors from the setup steps (1–4) are returned. Errors from I/O handling and from
    /// signing out are only logged; in that case the future still resolves to `Ok(())`.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // Mutable because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The closure passed to the peer creates a timer from the executor's `sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; a failed send is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Both database lookups run concurrently; the first error aborts the pair.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // The store lock is confined to this block, so it is released before the
            // `.await` that follows.
            {
                let mut store = this.store().await;
                let incoming_call = store.add_connection(connection_id, user_id, user.admin);
                if let Some(incoming_call) = incoming_call {
                    this.peer.send(connection_id, incoming_call)?;
                }

                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: I/O first, then
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                // Leave the span before moving it into `instrument` below.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                // Background messages are spawned on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that have not finished are dropped here, before sign-out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 424
    /// Disconnects `connection_id` and tells other connections that it is gone.
    ///
    /// While holding the store lock, the connection is removed from the store and the
    /// resulting notifications are sent: `UnshareProject` for projects it hosted,
    /// `RemoveProjectCollaborator` for projects it had joined as a guest, a room update if
    /// it was in a room, and `CallCanceled` for calls listed as canceled. After the lock is
    /// released, the pending "room left" future is awaited, contacts are refreshed for the
    /// affected users, and hosted projects are unshared in the database.
    ///
    /// # Errors
    /// Outside test builds, an error from `Store::remove_connection` is returned (test
    /// builds panic on it instead). Failures in the follow-up steps after the lock is
    /// released are only traced.
    #[instrument(skip(self), err)]
    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
        self.peer.disconnect(connection_id);

        let mut projects_to_unshare = Vec::new();
        let mut contacts_to_update = HashSet::default();
        let mut room_left = None;
        {
            let mut store = self.store().await;

            // Tests fail loudly if the connection is unknown; other builds return the error.
            #[cfg(test)]
            let removed_connection = store.remove_connection(connection_id).unwrap();
            #[cfg(not(test))]
            let removed_connection = store.remove_connection(connection_id)?;

            // Projects hosted by this connection: notify their guests and remember the ids
            // so they can be unshared in the database once the lock is released.
            for project in removed_connection.hosted_projects {
                projects_to_unshare.push(project.id);
                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
                    self.peer.send(
                        conn_id,
                        proto::UnshareProject {
                            project_id: project.id.to_proto(),
                        },
                    )
                });
            }

            // Projects this connection had joined as a guest.
            for project in removed_connection.guest_projects {
                broadcast(connection_id, project.connection_ids, |conn_id| {
                    self.peer.send(
                        conn_id,
                        proto::RemoveProjectCollaborator {
                            project_id: project.id.to_proto(),
                            peer_id: connection_id.0,
                        },
                    )
                });
            }

            if let Some(room) = removed_connection.room {
                self.room_updated(&room);
                // Only create the future here; it is awaited after the lock is released.
                room_left = Some(self.room_left(&room, connection_id));
            }

            contacts_to_update.insert(removed_connection.user_id);
            // Note: `connection_id` is shadowed inside this loop by the callee connection.
            for connection_id in removed_connection.canceled_call_connection_ids {
                self.peer
                    .send(connection_id, proto::CallCanceled {})
                    .trace_err();
                contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
            }
        };

        if let Some(room_left) = room_left {
            room_left.await.trace_err();
        }

        for user_id in contacts_to_update {
            self.update_user_contacts(user_id).await.trace_err();
        }

        for project_id in projects_to_unshare {
            self.app_state
                .db
                .unshare_project(project_id)
                .await
                .trace_err();
        }

        Ok(())
    }
 496
 497    pub async fn invite_code_redeemed(
 498        self: &Arc<Self>,
 499        inviter_id: UserId,
 500        invitee_id: UserId,
 501    ) -> Result<()> {
 502        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 503            if let Some(code) = &user.invite_code {
 504                let store = self.store().await;
 505                let invitee_contact = store.contact_for_user(invitee_id, true);
 506                for connection_id in store.connection_ids_for_user(inviter_id) {
 507                    self.peer.send(
 508                        connection_id,
 509                        proto::UpdateContacts {
 510                            contacts: vec![invitee_contact.clone()],
 511                            ..Default::default()
 512                        },
 513                    )?;
 514                    self.peer.send(
 515                        connection_id,
 516                        proto::UpdateInviteInfo {
 517                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 518                            count: user.invite_count as u32,
 519                        },
 520                    )?;
 521                }
 522            }
 523        }
 524        Ok(())
 525    }
 526
 527    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 528        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 529            if let Some(invite_code) = &user.invite_code {
 530                let store = self.store().await;
 531                for connection_id in store.connection_ids_for_user(user_id) {
 532                    self.peer.send(
 533                        connection_id,
 534                        proto::UpdateInviteInfo {
 535                            url: format!(
 536                                "{}{}",
 537                                self.app_state.config.invite_link_prefix, invite_code
 538                            ),
 539                            count: user.invite_count as u32,
 540                        },
 541                    )?;
 542                }
 543            }
 544        }
 545        Ok(())
 546    }
 547
 548    async fn ping(
 549        self: Arc<Server>,
 550        _: Message<proto::Ping>,
 551        response: Response<proto::Ping>,
 552    ) -> Result<()> {
 553        response.send(proto::Ack {})?;
 554        Ok(())
 555    }
 556
 557    async fn create_room(
 558        self: Arc<Server>,
 559        request: Message<proto::CreateRoom>,
 560        response: Response<proto::CreateRoom>,
 561    ) -> Result<()> {
 562        let room = self
 563            .app_state
 564            .db
 565            .create_room(request.sender_user_id, request.sender_connection_id)
 566            .await?;
 567
 568        let live_kit_connection_info =
 569            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 570                if let Some(_) = live_kit
 571                    .create_room(room.live_kit_room.clone())
 572                    .await
 573                    .trace_err()
 574                {
 575                    if let Some(token) = live_kit
 576                        .room_token(
 577                            &room.live_kit_room,
 578                            &request.sender_connection_id.to_string(),
 579                        )
 580                        .trace_err()
 581                    {
 582                        Some(proto::LiveKitConnectionInfo {
 583                            server_url: live_kit.url().into(),
 584                            token,
 585                        })
 586                    } else {
 587                        None
 588                    }
 589                } else {
 590                    None
 591                }
 592            } else {
 593                None
 594            };
 595
 596        response.send(proto::CreateRoomResponse {
 597            room: Some(room),
 598            live_kit_connection_info,
 599        })?;
 600        self.update_user_contacts(request.sender_user_id).await?;
 601        Ok(())
 602    }
 603
 604    async fn join_room(
 605        self: Arc<Server>,
 606        request: Message<proto::JoinRoom>,
 607        response: Response<proto::JoinRoom>,
 608    ) -> Result<()> {
 609        {
 610            let mut store = self.store().await;
 611            let (room, recipient_connection_ids) =
 612                store.join_room(request.payload.id, request.sender_connection_id)?;
 613            for recipient_id in recipient_connection_ids {
 614                self.peer
 615                    .send(recipient_id, proto::CallCanceled {})
 616                    .trace_err();
 617            }
 618
 619            let live_kit_connection_info =
 620                if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 621                    if let Some(token) = live_kit
 622                        .room_token(
 623                            &room.live_kit_room,
 624                            &request.sender_connection_id.to_string(),
 625                        )
 626                        .trace_err()
 627                    {
 628                        Some(proto::LiveKitConnectionInfo {
 629                            server_url: live_kit.url().into(),
 630                            token,
 631                        })
 632                    } else {
 633                        None
 634                    }
 635                } else {
 636                    None
 637                };
 638
 639            response.send(proto::JoinRoomResponse {
 640                room: Some(room.clone()),
 641                live_kit_connection_info,
 642            })?;
 643            self.room_updated(room);
 644        }
 645        self.update_user_contacts(request.sender_user_id).await?;
 646        Ok(())
 647    }
 648
 649    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 650        let mut contacts_to_update = HashSet::default();
 651        let room_left;
 652        {
 653            let mut store = self.store().await;
 654            let left_room = store.leave_room(message.payload.id, message.sender_connection_id)?;
 655            contacts_to_update.insert(message.sender_user_id);
 656
 657            for project in left_room.unshared_projects {
 658                for connection_id in project.connection_ids() {
 659                    self.peer.send(
 660                        connection_id,
 661                        proto::UnshareProject {
 662                            project_id: project.id.to_proto(),
 663                        },
 664                    )?;
 665                }
 666            }
 667
 668            for project in left_room.left_projects {
 669                if project.remove_collaborator {
 670                    for connection_id in project.connection_ids {
 671                        self.peer.send(
 672                            connection_id,
 673                            proto::RemoveProjectCollaborator {
 674                                project_id: project.id.to_proto(),
 675                                peer_id: message.sender_connection_id.0,
 676                            },
 677                        )?;
 678                    }
 679
 680                    self.peer.send(
 681                        message.sender_connection_id,
 682                        proto::UnshareProject {
 683                            project_id: project.id.to_proto(),
 684                        },
 685                    )?;
 686                }
 687            }
 688
 689            self.room_updated(&left_room.room);
 690            room_left = self.room_left(&left_room.room, message.sender_connection_id);
 691
 692            for connection_id in left_room.canceled_call_connection_ids {
 693                self.peer
 694                    .send(connection_id, proto::CallCanceled {})
 695                    .trace_err();
 696                contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
 697            }
 698        }
 699
 700        room_left.await.trace_err();
 701        for user_id in contacts_to_update {
 702            self.update_user_contacts(user_id).await?;
 703        }
 704
 705        Ok(())
 706    }
 707
 708    async fn call(
 709        self: Arc<Server>,
 710        request: Message<proto::Call>,
 711        response: Response<proto::Call>,
 712    ) -> Result<()> {
 713        let room_id = RoomId::from_proto(request.payload.room_id);
 714        let calling_user_id = request.sender_user_id;
 715        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 716        let initial_project_id = request
 717            .payload
 718            .initial_project_id
 719            .map(ProjectId::from_proto);
 720        if !self
 721            .app_state
 722            .db
 723            .has_contact(calling_user_id, called_user_id)
 724            .await?
 725        {
 726            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 727        }
 728
 729        let room = self
 730            .app_state
 731            .db
 732            .call(room_id, calling_user_id, called_user_id, initial_project_id)
 733            .await?;
 734        self.room_updated(&room);
 735        self.update_user_contacts(called_user_id).await?;
 736
 737        let incoming_call = proto::IncomingCall {
 738            room_id: room_id.to_proto(),
 739            calling_user_id: calling_user_id.to_proto(),
 740            participant_user_ids: room
 741                .participants
 742                .iter()
 743                .map(|participant| participant.user_id)
 744                .collect(),
 745            initial_project: room.participants.iter().find_map(|participant| {
 746                let initial_project_id = initial_project_id?.to_proto();
 747                participant
 748                    .projects
 749                    .iter()
 750                    .find(|project| project.id == initial_project_id)
 751                    .cloned()
 752            }),
 753        };
 754
 755        let mut calls = self
 756            .store()
 757            .await
 758            .connection_ids_for_user(called_user_id)
 759            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 760            .collect::<FuturesUnordered<_>>();
 761
 762        while let Some(call_response) = calls.next().await {
 763            match call_response.as_ref() {
 764                Ok(_) => {
 765                    response.send(proto::Ack {})?;
 766                    return Ok(());
 767                }
 768                Err(_) => {
 769                    call_response.trace_err();
 770                }
 771            }
 772        }
 773
 774        let room = self
 775            .app_state
 776            .db
 777            .call_failed(room_id, called_user_id)
 778            .await?;
 779        self.room_updated(&room);
 780        self.update_user_contacts(called_user_id).await?;
 781
 782        Err(anyhow!("failed to ring call recipient"))?
 783    }
 784
 785    async fn cancel_call(
 786        self: Arc<Server>,
 787        request: Message<proto::CancelCall>,
 788        response: Response<proto::CancelCall>,
 789    ) -> Result<()> {
 790        let recipient_user_id = UserId::from_proto(request.payload.called_user_id);
 791        {
 792            let mut store = self.store().await;
 793            let (room, recipient_connection_ids) = store.cancel_call(
 794                request.payload.room_id,
 795                recipient_user_id,
 796                request.sender_connection_id,
 797            )?;
 798            for recipient_id in recipient_connection_ids {
 799                self.peer
 800                    .send(recipient_id, proto::CallCanceled {})
 801                    .trace_err();
 802            }
 803            self.room_updated(room);
 804            response.send(proto::Ack {})?;
 805        }
 806        self.update_user_contacts(recipient_user_id).await?;
 807        Ok(())
 808    }
 809
 810    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 811        let recipient_user_id = message.sender_user_id;
 812        {
 813            let mut store = self.store().await;
 814            let (room, recipient_connection_ids) =
 815                store.decline_call(message.payload.room_id, message.sender_connection_id)?;
 816            for recipient_id in recipient_connection_ids {
 817                self.peer
 818                    .send(recipient_id, proto::CallCanceled {})
 819                    .trace_err();
 820            }
 821            self.room_updated(room);
 822        }
 823        self.update_user_contacts(recipient_user_id).await?;
 824        Ok(())
 825    }
 826
 827    async fn update_participant_location(
 828        self: Arc<Server>,
 829        request: Message<proto::UpdateParticipantLocation>,
 830        response: Response<proto::UpdateParticipantLocation>,
 831    ) -> Result<()> {
 832        let room_id = RoomId::from_proto(request.payload.room_id);
 833        let location = request
 834            .payload
 835            .location
 836            .ok_or_else(|| anyhow!("invalid location"))?;
 837        let room = self
 838            .app_state
 839            .db
 840            .update_room_participant_location(room_id, request.sender_user_id, location)
 841            .await?;
 842        self.room_updated(&room);
 843        response.send(proto::Ack {})?;
 844        Ok(())
 845    }
 846
 847    fn room_updated(&self, room: &proto::Room) {
 848        for participant in &room.participants {
 849            self.peer
 850                .send(
 851                    ConnectionId(participant.peer_id),
 852                    proto::RoomUpdated {
 853                        room: Some(room.clone()),
 854                    },
 855                )
 856                .trace_err();
 857        }
 858    }
 859
 860    fn room_left(
 861        &self,
 862        room: &proto::Room,
 863        connection_id: ConnectionId,
 864    ) -> impl Future<Output = Result<()>> {
 865        let client = self.app_state.live_kit_client.clone();
 866        let room_name = room.live_kit_room.clone();
 867        let participant_count = room.participants.len();
 868        async move {
 869            if let Some(client) = client {
 870                client
 871                    .remove_participant(room_name.clone(), connection_id.to_string())
 872                    .await?;
 873
 874                if participant_count == 0 {
 875                    client.delete_room(room_name).await?;
 876                }
 877            }
 878
 879            Ok(())
 880        }
 881    }
 882
 883    async fn share_project(
 884        self: Arc<Server>,
 885        request: Message<proto::ShareProject>,
 886        response: Response<proto::ShareProject>,
 887    ) -> Result<()> {
 888        let (project_id, room) = self
 889            .app_state
 890            .db
 891            .share_project(
 892                request.sender_user_id,
 893                request.sender_connection_id,
 894                RoomId::from_proto(request.payload.room_id),
 895                &request.payload.worktrees,
 896            )
 897            .await?;
 898        response.send(proto::ShareProjectResponse {
 899            project_id: project_id.to_proto(),
 900        })?;
 901        self.room_updated(&room);
 902
 903        Ok(())
 904    }
 905
 906    async fn unshare_project(
 907        self: Arc<Server>,
 908        message: Message<proto::UnshareProject>,
 909    ) -> Result<()> {
 910        let project_id = ProjectId::from_proto(message.payload.project_id);
 911        let mut store = self.store().await;
 912        let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?;
 913        broadcast(
 914            message.sender_connection_id,
 915            project.guest_connection_ids(),
 916            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 917        );
 918        self.room_updated(room);
 919
 920        Ok(())
 921    }
 922
 923    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 924        let contacts = self.app_state.db.get_contacts(user_id).await?;
 925        let store = self.store().await;
 926        let updated_contact = store.contact_for_user(user_id, false);
 927        for contact in contacts {
 928            if let db::Contact::Accepted {
 929                user_id: contact_user_id,
 930                ..
 931            } = contact
 932            {
 933                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 934                    self.peer
 935                        .send(
 936                            contact_conn_id,
 937                            proto::UpdateContacts {
 938                                contacts: vec![updated_contact.clone()],
 939                                remove_contacts: Default::default(),
 940                                incoming_requests: Default::default(),
 941                                remove_incoming_requests: Default::default(),
 942                                outgoing_requests: Default::default(),
 943                                remove_outgoing_requests: Default::default(),
 944                            },
 945                        )
 946                        .trace_err();
 947                }
 948            }
 949        }
 950        Ok(())
 951    }
 952
 953    async fn join_project(
 954        self: Arc<Server>,
 955        request: Message<proto::JoinProject>,
 956        response: Response<proto::JoinProject>,
 957    ) -> Result<()> {
 958        let project_id = ProjectId::from_proto(request.payload.project_id);
 959        let guest_user_id = request.sender_user_id;
 960        let host_user_id;
 961        let host_connection_id;
 962        {
 963            let state = self.store().await;
 964            let project = state.project(project_id)?;
 965            host_user_id = project.host.user_id;
 966            host_connection_id = project.host_connection_id;
 967        };
 968
 969        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 970
 971        let mut store = self.store().await;
 972        let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?;
 973        let peer_count = project.guests.len();
 974        let mut collaborators = Vec::with_capacity(peer_count);
 975        collaborators.push(proto::Collaborator {
 976            peer_id: project.host_connection_id.0,
 977            replica_id: 0,
 978            user_id: project.host.user_id.to_proto(),
 979        });
 980        let worktrees = project
 981            .worktrees
 982            .iter()
 983            .map(|(id, worktree)| proto::WorktreeMetadata {
 984                id: *id,
 985                root_name: worktree.root_name.clone(),
 986                visible: worktree.visible,
 987                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
 988            })
 989            .collect::<Vec<_>>();
 990
 991        // Add all guests other than the requesting user's own connections as collaborators
 992        for (guest_conn_id, guest) in &project.guests {
 993            if request.sender_connection_id != *guest_conn_id {
 994                collaborators.push(proto::Collaborator {
 995                    peer_id: guest_conn_id.0,
 996                    replica_id: guest.replica_id as u32,
 997                    user_id: guest.user_id.to_proto(),
 998                });
 999            }
1000        }
1001
1002        for conn_id in project.connection_ids() {
1003            if conn_id != request.sender_connection_id {
1004                self.peer
1005                    .send(
1006                        conn_id,
1007                        proto::AddProjectCollaborator {
1008                            project_id: project_id.to_proto(),
1009                            collaborator: Some(proto::Collaborator {
1010                                peer_id: request.sender_connection_id.0,
1011                                replica_id: replica_id as u32,
1012                                user_id: guest_user_id.to_proto(),
1013                            }),
1014                        },
1015                    )
1016                    .trace_err();
1017            }
1018        }
1019
1020        // First, we send the metadata associated with each worktree.
1021        response.send(proto::JoinProjectResponse {
1022            worktrees: worktrees.clone(),
1023            replica_id: replica_id as u32,
1024            collaborators: collaborators.clone(),
1025            language_servers: project.language_servers.clone(),
1026        })?;
1027
1028        for (worktree_id, worktree) in &project.worktrees {
1029            #[cfg(any(test, feature = "test-support"))]
1030            const MAX_CHUNK_SIZE: usize = 2;
1031            #[cfg(not(any(test, feature = "test-support")))]
1032            const MAX_CHUNK_SIZE: usize = 256;
1033
1034            // Stream this worktree's entries.
1035            let message = proto::UpdateWorktree {
1036                project_id: project_id.to_proto(),
1037                worktree_id: *worktree_id,
1038                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
1039                root_name: worktree.root_name.clone(),
1040                updated_entries: worktree.entries.values().cloned().collect(),
1041                removed_entries: Default::default(),
1042                scan_id: worktree.scan_id,
1043                is_last_update: worktree.is_complete,
1044            };
1045            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1046                self.peer
1047                    .send(request.sender_connection_id, update.clone())?;
1048            }
1049
1050            // Stream this worktree's diagnostics.
1051            for summary in worktree.diagnostic_summaries.values() {
1052                self.peer.send(
1053                    request.sender_connection_id,
1054                    proto::UpdateDiagnosticSummary {
1055                        project_id: project_id.to_proto(),
1056                        worktree_id: *worktree_id,
1057                        summary: Some(summary.clone()),
1058                    },
1059                )?;
1060            }
1061        }
1062
1063        for language_server in &project.language_servers {
1064            self.peer.send(
1065                request.sender_connection_id,
1066                proto::UpdateLanguageServer {
1067                    project_id: project_id.to_proto(),
1068                    language_server_id: language_server.id,
1069                    variant: Some(
1070                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1071                            proto::LspDiskBasedDiagnosticsUpdated {},
1072                        ),
1073                    ),
1074                },
1075            )?;
1076        }
1077
1078        Ok(())
1079    }
1080
1081    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1082        let sender_id = request.sender_connection_id;
1083        let project_id = ProjectId::from_proto(request.payload.project_id);
1084        let project;
1085        {
1086            let mut store = self.store().await;
1087            project = store.leave_project(project_id, sender_id)?;
1088            tracing::info!(
1089                %project_id,
1090                host_user_id = %project.host_user_id,
1091                host_connection_id = %project.host_connection_id,
1092                "leave project"
1093            );
1094
1095            if project.remove_collaborator {
1096                broadcast(sender_id, project.connection_ids, |conn_id| {
1097                    self.peer.send(
1098                        conn_id,
1099                        proto::RemoveProjectCollaborator {
1100                            project_id: project_id.to_proto(),
1101                            peer_id: sender_id.0,
1102                        },
1103                    )
1104                });
1105            }
1106        }
1107
1108        Ok(())
1109    }
1110
1111    async fn update_project(
1112        self: Arc<Server>,
1113        request: Message<proto::UpdateProject>,
1114    ) -> Result<()> {
1115        let project_id = ProjectId::from_proto(request.payload.project_id);
1116        {
1117            let mut state = self.store().await;
1118            let guest_connection_ids = state
1119                .read_project(project_id, request.sender_connection_id)?
1120                .guest_connection_ids();
1121            let room = state.update_project(
1122                project_id,
1123                &request.payload.worktrees,
1124                request.sender_connection_id,
1125            )?;
1126            broadcast(
1127                request.sender_connection_id,
1128                guest_connection_ids,
1129                |connection_id| {
1130                    self.peer.forward_send(
1131                        request.sender_connection_id,
1132                        connection_id,
1133                        request.payload.clone(),
1134                    )
1135                },
1136            );
1137            self.room_updated(room);
1138        };
1139
1140        Ok(())
1141    }
1142
1143    async fn update_worktree(
1144        self: Arc<Server>,
1145        request: Message<proto::UpdateWorktree>,
1146        response: Response<proto::UpdateWorktree>,
1147    ) -> Result<()> {
1148        let project_id = ProjectId::from_proto(request.payload.project_id);
1149        let worktree_id = request.payload.worktree_id;
1150        let connection_ids = self.store().await.update_worktree(
1151            request.sender_connection_id,
1152            project_id,
1153            worktree_id,
1154            &request.payload.root_name,
1155            &request.payload.removed_entries,
1156            &request.payload.updated_entries,
1157            request.payload.scan_id,
1158            request.payload.is_last_update,
1159        )?;
1160
1161        broadcast(
1162            request.sender_connection_id,
1163            connection_ids,
1164            |connection_id| {
1165                self.peer.forward_send(
1166                    request.sender_connection_id,
1167                    connection_id,
1168                    request.payload.clone(),
1169                )
1170            },
1171        );
1172        response.send(proto::Ack {})?;
1173        Ok(())
1174    }
1175
1176    async fn update_diagnostic_summary(
1177        self: Arc<Server>,
1178        request: Message<proto::UpdateDiagnosticSummary>,
1179    ) -> Result<()> {
1180        let summary = request
1181            .payload
1182            .summary
1183            .clone()
1184            .ok_or_else(|| anyhow!("invalid summary"))?;
1185        let receiver_ids = self.store().await.update_diagnostic_summary(
1186            ProjectId::from_proto(request.payload.project_id),
1187            request.payload.worktree_id,
1188            request.sender_connection_id,
1189            summary,
1190        )?;
1191
1192        broadcast(
1193            request.sender_connection_id,
1194            receiver_ids,
1195            |connection_id| {
1196                self.peer.forward_send(
1197                    request.sender_connection_id,
1198                    connection_id,
1199                    request.payload.clone(),
1200                )
1201            },
1202        );
1203        Ok(())
1204    }
1205
1206    async fn start_language_server(
1207        self: Arc<Server>,
1208        request: Message<proto::StartLanguageServer>,
1209    ) -> Result<()> {
1210        let receiver_ids = self.store().await.start_language_server(
1211            ProjectId::from_proto(request.payload.project_id),
1212            request.sender_connection_id,
1213            request
1214                .payload
1215                .server
1216                .clone()
1217                .ok_or_else(|| anyhow!("invalid language server"))?,
1218        )?;
1219        broadcast(
1220            request.sender_connection_id,
1221            receiver_ids,
1222            |connection_id| {
1223                self.peer.forward_send(
1224                    request.sender_connection_id,
1225                    connection_id,
1226                    request.payload.clone(),
1227                )
1228            },
1229        );
1230        Ok(())
1231    }
1232
1233    async fn update_language_server(
1234        self: Arc<Server>,
1235        request: Message<proto::UpdateLanguageServer>,
1236    ) -> Result<()> {
1237        let receiver_ids = self.store().await.project_connection_ids(
1238            ProjectId::from_proto(request.payload.project_id),
1239            request.sender_connection_id,
1240        )?;
1241        broadcast(
1242            request.sender_connection_id,
1243            receiver_ids,
1244            |connection_id| {
1245                self.peer.forward_send(
1246                    request.sender_connection_id,
1247                    connection_id,
1248                    request.payload.clone(),
1249                )
1250            },
1251        );
1252        Ok(())
1253    }
1254
1255    async fn forward_project_request<T>(
1256        self: Arc<Server>,
1257        request: Message<T>,
1258        response: Response<T>,
1259    ) -> Result<()>
1260    where
1261        T: EntityMessage + RequestMessage,
1262    {
1263        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1264        let host_connection_id = self
1265            .store()
1266            .await
1267            .read_project(project_id, request.sender_connection_id)?
1268            .host_connection_id;
1269        let payload = self
1270            .peer
1271            .forward_request(
1272                request.sender_connection_id,
1273                host_connection_id,
1274                request.payload,
1275            )
1276            .await?;
1277
1278        // Ensure project still exists by the time we get the response from the host.
1279        self.store()
1280            .await
1281            .read_project(project_id, request.sender_connection_id)?;
1282
1283        response.send(payload)?;
1284        Ok(())
1285    }
1286
1287    async fn save_buffer(
1288        self: Arc<Server>,
1289        request: Message<proto::SaveBuffer>,
1290        response: Response<proto::SaveBuffer>,
1291    ) -> Result<()> {
1292        let project_id = ProjectId::from_proto(request.payload.project_id);
1293        let host = self
1294            .store()
1295            .await
1296            .read_project(project_id, request.sender_connection_id)?
1297            .host_connection_id;
1298        let response_payload = self
1299            .peer
1300            .forward_request(request.sender_connection_id, host, request.payload.clone())
1301            .await?;
1302
1303        let mut guests = self
1304            .store()
1305            .await
1306            .read_project(project_id, request.sender_connection_id)?
1307            .connection_ids();
1308        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
1309        broadcast(host, guests, |conn_id| {
1310            self.peer
1311                .forward_send(host, conn_id, response_payload.clone())
1312        });
1313        response.send(response_payload)?;
1314        Ok(())
1315    }
1316
1317    async fn create_buffer_for_peer(
1318        self: Arc<Server>,
1319        request: Message<proto::CreateBufferForPeer>,
1320    ) -> Result<()> {
1321        self.peer.forward_send(
1322            request.sender_connection_id,
1323            ConnectionId(request.payload.peer_id),
1324            request.payload,
1325        )?;
1326        Ok(())
1327    }
1328
1329    async fn update_buffer(
1330        self: Arc<Server>,
1331        request: Message<proto::UpdateBuffer>,
1332        response: Response<proto::UpdateBuffer>,
1333    ) -> Result<()> {
1334        let project_id = ProjectId::from_proto(request.payload.project_id);
1335        let receiver_ids = {
1336            let store = self.store().await;
1337            store.project_connection_ids(project_id, request.sender_connection_id)?
1338        };
1339
1340        broadcast(
1341            request.sender_connection_id,
1342            receiver_ids,
1343            |connection_id| {
1344                self.peer.forward_send(
1345                    request.sender_connection_id,
1346                    connection_id,
1347                    request.payload.clone(),
1348                )
1349            },
1350        );
1351        response.send(proto::Ack {})?;
1352        Ok(())
1353    }
1354
1355    async fn update_buffer_file(
1356        self: Arc<Server>,
1357        request: Message<proto::UpdateBufferFile>,
1358    ) -> Result<()> {
1359        let receiver_ids = self.store().await.project_connection_ids(
1360            ProjectId::from_proto(request.payload.project_id),
1361            request.sender_connection_id,
1362        )?;
1363        broadcast(
1364            request.sender_connection_id,
1365            receiver_ids,
1366            |connection_id| {
1367                self.peer.forward_send(
1368                    request.sender_connection_id,
1369                    connection_id,
1370                    request.payload.clone(),
1371                )
1372            },
1373        );
1374        Ok(())
1375    }
1376
1377    async fn buffer_reloaded(
1378        self: Arc<Server>,
1379        request: Message<proto::BufferReloaded>,
1380    ) -> Result<()> {
1381        let receiver_ids = self.store().await.project_connection_ids(
1382            ProjectId::from_proto(request.payload.project_id),
1383            request.sender_connection_id,
1384        )?;
1385        broadcast(
1386            request.sender_connection_id,
1387            receiver_ids,
1388            |connection_id| {
1389                self.peer.forward_send(
1390                    request.sender_connection_id,
1391                    connection_id,
1392                    request.payload.clone(),
1393                )
1394            },
1395        );
1396        Ok(())
1397    }
1398
1399    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1400        let receiver_ids = self.store().await.project_connection_ids(
1401            ProjectId::from_proto(request.payload.project_id),
1402            request.sender_connection_id,
1403        )?;
1404        broadcast(
1405            request.sender_connection_id,
1406            receiver_ids,
1407            |connection_id| {
1408                self.peer.forward_send(
1409                    request.sender_connection_id,
1410                    connection_id,
1411                    request.payload.clone(),
1412                )
1413            },
1414        );
1415        Ok(())
1416    }
1417
1418    async fn follow(
1419        self: Arc<Self>,
1420        request: Message<proto::Follow>,
1421        response: Response<proto::Follow>,
1422    ) -> Result<()> {
1423        let project_id = ProjectId::from_proto(request.payload.project_id);
1424        let leader_id = ConnectionId(request.payload.leader_id);
1425        let follower_id = request.sender_connection_id;
1426        {
1427            let store = self.store().await;
1428            if !store
1429                .project_connection_ids(project_id, follower_id)?
1430                .contains(&leader_id)
1431            {
1432                Err(anyhow!("no such peer"))?;
1433            }
1434        }
1435
1436        let mut response_payload = self
1437            .peer
1438            .forward_request(request.sender_connection_id, leader_id, request.payload)
1439            .await?;
1440        response_payload
1441            .views
1442            .retain(|view| view.leader_id != Some(follower_id.0));
1443        response.send(response_payload)?;
1444        Ok(())
1445    }
1446
1447    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1448        let project_id = ProjectId::from_proto(request.payload.project_id);
1449        let leader_id = ConnectionId(request.payload.leader_id);
1450        let store = self.store().await;
1451        if !store
1452            .project_connection_ids(project_id, request.sender_connection_id)?
1453            .contains(&leader_id)
1454        {
1455            Err(anyhow!("no such peer"))?;
1456        }
1457        self.peer
1458            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1459        Ok(())
1460    }
1461
1462    async fn update_followers(
1463        self: Arc<Self>,
1464        request: Message<proto::UpdateFollowers>,
1465    ) -> Result<()> {
1466        let project_id = ProjectId::from_proto(request.payload.project_id);
1467        let store = self.store().await;
1468        let connection_ids =
1469            store.project_connection_ids(project_id, request.sender_connection_id)?;
1470        let leader_id = request
1471            .payload
1472            .variant
1473            .as_ref()
1474            .and_then(|variant| match variant {
1475                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1476                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1477                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1478            });
1479        for follower_id in &request.payload.follower_ids {
1480            let follower_id = ConnectionId(*follower_id);
1481            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1482                self.peer.forward_send(
1483                    request.sender_connection_id,
1484                    follower_id,
1485                    request.payload.clone(),
1486                )?;
1487            }
1488        }
1489        Ok(())
1490    }
1491
1492    async fn get_users(
1493        self: Arc<Server>,
1494        request: Message<proto::GetUsers>,
1495        response: Response<proto::GetUsers>,
1496    ) -> Result<()> {
1497        let user_ids = request
1498            .payload
1499            .user_ids
1500            .into_iter()
1501            .map(UserId::from_proto)
1502            .collect();
1503        let users = self
1504            .app_state
1505            .db
1506            .get_users_by_ids(user_ids)
1507            .await?
1508            .into_iter()
1509            .map(|user| proto::User {
1510                id: user.id.to_proto(),
1511                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1512                github_login: user.github_login,
1513            })
1514            .collect();
1515        response.send(proto::UsersResponse { users })?;
1516        Ok(())
1517    }
1518
1519    async fn fuzzy_search_users(
1520        self: Arc<Server>,
1521        request: Message<proto::FuzzySearchUsers>,
1522        response: Response<proto::FuzzySearchUsers>,
1523    ) -> Result<()> {
1524        let query = request.payload.query;
1525        let db = &self.app_state.db;
1526        let users = match query.len() {
1527            0 => vec![],
1528            1 | 2 => db
1529                .get_user_by_github_account(&query, None)
1530                .await?
1531                .into_iter()
1532                .collect(),
1533            _ => db.fuzzy_search_users(&query, 10).await?,
1534        };
1535        let users = users
1536            .into_iter()
1537            .filter(|user| user.id != request.sender_user_id)
1538            .map(|user| proto::User {
1539                id: user.id.to_proto(),
1540                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1541                github_login: user.github_login,
1542            })
1543            .collect();
1544        response.send(proto::UsersResponse { users })?;
1545        Ok(())
1546    }
1547
1548    async fn request_contact(
1549        self: Arc<Server>,
1550        request: Message<proto::RequestContact>,
1551        response: Response<proto::RequestContact>,
1552    ) -> Result<()> {
1553        let requester_id = request.sender_user_id;
1554        let responder_id = UserId::from_proto(request.payload.responder_id);
1555        if requester_id == responder_id {
1556            return Err(anyhow!("cannot add yourself as a contact"))?;
1557        }
1558
1559        self.app_state
1560            .db
1561            .send_contact_request(requester_id, responder_id)
1562            .await?;
1563
1564        // Update outgoing contact requests of requester
1565        let mut update = proto::UpdateContacts::default();
1566        update.outgoing_requests.push(responder_id.to_proto());
1567        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1568            self.peer.send(connection_id, update.clone())?;
1569        }
1570
1571        // Update incoming contact requests of responder
1572        let mut update = proto::UpdateContacts::default();
1573        update
1574            .incoming_requests
1575            .push(proto::IncomingContactRequest {
1576                requester_id: requester_id.to_proto(),
1577                should_notify: true,
1578            });
1579        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1580            self.peer.send(connection_id, update.clone())?;
1581        }
1582
1583        response.send(proto::Ack {})?;
1584        Ok(())
1585    }
1586
1587    async fn respond_to_contact_request(
1588        self: Arc<Server>,
1589        request: Message<proto::RespondToContactRequest>,
1590        response: Response<proto::RespondToContactRequest>,
1591    ) -> Result<()> {
1592        let responder_id = request.sender_user_id;
1593        let requester_id = UserId::from_proto(request.payload.requester_id);
1594        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1595            self.app_state
1596                .db
1597                .dismiss_contact_notification(responder_id, requester_id)
1598                .await?;
1599        } else {
1600            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1601            self.app_state
1602                .db
1603                .respond_to_contact_request(responder_id, requester_id, accept)
1604                .await?;
1605
1606            let store = self.store().await;
1607            // Update responder with new contact
1608            let mut update = proto::UpdateContacts::default();
1609            if accept {
1610                update
1611                    .contacts
1612                    .push(store.contact_for_user(requester_id, false));
1613            }
1614            update
1615                .remove_incoming_requests
1616                .push(requester_id.to_proto());
1617            for connection_id in store.connection_ids_for_user(responder_id) {
1618                self.peer.send(connection_id, update.clone())?;
1619            }
1620
1621            // Update requester with new contact
1622            let mut update = proto::UpdateContacts::default();
1623            if accept {
1624                update
1625                    .contacts
1626                    .push(store.contact_for_user(responder_id, true));
1627            }
1628            update
1629                .remove_outgoing_requests
1630                .push(responder_id.to_proto());
1631            for connection_id in store.connection_ids_for_user(requester_id) {
1632                self.peer.send(connection_id, update.clone())?;
1633            }
1634        }
1635
1636        response.send(proto::Ack {})?;
1637        Ok(())
1638    }
1639
1640    async fn remove_contact(
1641        self: Arc<Server>,
1642        request: Message<proto::RemoveContact>,
1643        response: Response<proto::RemoveContact>,
1644    ) -> Result<()> {
1645        let requester_id = request.sender_user_id;
1646        let responder_id = UserId::from_proto(request.payload.user_id);
1647        self.app_state
1648            .db
1649            .remove_contact(requester_id, responder_id)
1650            .await?;
1651
1652        // Update outgoing contact requests of requester
1653        let mut update = proto::UpdateContacts::default();
1654        update
1655            .remove_outgoing_requests
1656            .push(responder_id.to_proto());
1657        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1658            self.peer.send(connection_id, update.clone())?;
1659        }
1660
1661        // Update incoming contact requests of responder
1662        let mut update = proto::UpdateContacts::default();
1663        update
1664            .remove_incoming_requests
1665            .push(requester_id.to_proto());
1666        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1667            self.peer.send(connection_id, update.clone())?;
1668        }
1669
1670        response.send(proto::Ack {})?;
1671        Ok(())
1672    }
1673
1674    async fn update_diff_base(
1675        self: Arc<Server>,
1676        request: Message<proto::UpdateDiffBase>,
1677    ) -> Result<()> {
1678        let receiver_ids = self.store().await.project_connection_ids(
1679            ProjectId::from_proto(request.payload.project_id),
1680            request.sender_connection_id,
1681        )?;
1682        broadcast(
1683            request.sender_connection_id,
1684            receiver_ids,
1685            |connection_id| {
1686                self.peer.forward_send(
1687                    request.sender_connection_id,
1688                    connection_id,
1689                    request.payload.clone(),
1690                )
1691            },
1692        );
1693        Ok(())
1694    }
1695
1696    async fn get_private_user_info(
1697        self: Arc<Self>,
1698        request: Message<proto::GetPrivateUserInfo>,
1699        response: Response<proto::GetPrivateUserInfo>,
1700    ) -> Result<()> {
1701        let metrics_id = self
1702            .app_state
1703            .db
1704            .get_user_metrics_id(request.sender_user_id)
1705            .await?;
1706        let user = self
1707            .app_state
1708            .db
1709            .get_user_by_id(request.sender_user_id)
1710            .await?
1711            .ok_or_else(|| anyhow!("user not found"))?;
1712        response.send(proto::GetPrivateUserInfoResponse {
1713            metrics_id,
1714            staff: user.admin,
1715        })?;
1716        Ok(())
1717    }
1718
1719    pub(crate) async fn store(&self) -> StoreGuard<'_> {
1720        #[cfg(test)]
1721        tokio::task::yield_now().await;
1722        let guard = self.store.lock().await;
1723        #[cfg(test)]
1724        tokio::task::yield_now().await;
1725        StoreGuard {
1726            guard,
1727            _not_send: PhantomData,
1728        }
1729    }
1730
1731    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1732        ServerSnapshot {
1733            store: self.store().await,
1734            peer: &self.peer,
1735        }
1736    }
1737}
1738
1739impl<'a> Deref for StoreGuard<'a> {
1740    type Target = Store;
1741
1742    fn deref(&self) -> &Self::Target {
1743        &*self.guard
1744    }
1745}
1746
1747impl<'a> DerefMut for StoreGuard<'a> {
1748    fn deref_mut(&mut self) -> &mut Self::Target {
1749        &mut *self.guard
1750    }
1751}
1752
1753impl<'a> Drop for StoreGuard<'a> {
1754    fn drop(&mut self) {
1755        #[cfg(test)]
1756        self.check_invariants();
1757    }
1758}
1759
1760impl Executor for RealExecutor {
1761    type Sleep = Sleep;
1762
1763    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1764        tokio::task::spawn(future);
1765    }
1766
1767    fn sleep(&self, duration: Duration) -> Self::Sleep {
1768        tokio::time::sleep(duration)
1769    }
1770}
1771
1772fn broadcast<F>(
1773    sender_id: ConnectionId,
1774    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1775    mut f: F,
1776) where
1777    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1778{
1779    for receiver_id in receiver_ids {
1780        if receiver_id != sender_id {
1781            f(receiver_id).trace_err();
1782        }
1783    }
1784}
1785
1786lazy_static! {
1787    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1788}
1789
1790pub struct ProtocolVersion(u32);
1791
1792impl Header for ProtocolVersion {
1793    fn name() -> &'static HeaderName {
1794        &ZED_PROTOCOL_VERSION
1795    }
1796
1797    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1798    where
1799        Self: Sized,
1800        I: Iterator<Item = &'i axum::http::HeaderValue>,
1801    {
1802        let version = values
1803            .next()
1804            .ok_or_else(axum::headers::Error::invalid)?
1805            .to_str()
1806            .map_err(|_| axum::headers::Error::invalid())?
1807            .parse()
1808            .map_err(|_| axum::headers::Error::invalid())?;
1809        Ok(Self(version))
1810    }
1811
1812    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1813        values.extend([self.0.to_string().parse().unwrap()]);
1814    }
1815}
1816
1817pub fn routes(server: Arc<Server>) -> Router<Body> {
1818    Router::new()
1819        .route("/rpc", get(handle_websocket_request))
1820        .layer(
1821            ServiceBuilder::new()
1822                .layer(Extension(server.app_state.clone()))
1823                .layer(middleware::from_fn(auth::validate_header)),
1824        )
1825        .route("/metrics", get(handle_metrics))
1826        .layer(Extension(server))
1827}
1828
1829pub async fn handle_websocket_request(
1830    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1831    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1832    Extension(server): Extension<Arc<Server>>,
1833    Extension(user): Extension<User>,
1834    ws: WebSocketUpgrade,
1835) -> axum::response::Response {
1836    if protocol_version != rpc::PROTOCOL_VERSION {
1837        return (
1838            StatusCode::UPGRADE_REQUIRED,
1839            "client must be upgraded".to_string(),
1840        )
1841            .into_response();
1842    }
1843    let socket_address = socket_address.to_string();
1844    ws.on_upgrade(move |socket| {
1845        use util::ResultExt;
1846        let socket = socket
1847            .map_ok(to_tungstenite_message)
1848            .err_into()
1849            .with(|message| async move { Ok(to_axum_message(message)) });
1850        let connection = Connection::new(Box::pin(socket));
1851        async move {
1852            server
1853                .handle_connection(connection, socket_address, user, None, RealExecutor)
1854                .await
1855                .log_err();
1856        }
1857    })
1858}
1859
1860pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1861    let metrics = server.store().await.metrics();
1862    METRIC_CONNECTIONS.set(metrics.connections as _);
1863    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1864
1865    let encoder = prometheus::TextEncoder::new();
1866    let metric_families = prometheus::gather();
1867    match encoder.encode_to_string(&metric_families) {
1868        Ok(string) => (StatusCode::OK, string).into_response(),
1869        Err(error) => (
1870            StatusCode::INTERNAL_SERVER_ERROR,
1871            format!("failed to encode metrics {:?}", error),
1872        )
1873            .into_response(),
1874    }
1875}
1876
1877fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1878    match message {
1879        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1880        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1881        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1882        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1883        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1884            code: frame.code.into(),
1885            reason: frame.reason,
1886        })),
1887    }
1888}
1889
1890fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1891    match message {
1892        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1893        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1894        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1895        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1896        AxumMessage::Close(frame) => {
1897            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1898                code: frame.code.into(),
1899                reason: frame.reason,
1900            }))
1901        }
1902    }
1903}
1904
1905pub trait ResultExt {
1906    type Ok;
1907
1908    fn trace_err(self) -> Option<Self::Ok>;
1909}
1910
1911impl<T, E> ResultExt for Result<T, E>
1912where
1913    E: std::fmt::Debug,
1914{
1915    type Ok = T;
1916
1917    fn trace_err(self) -> Option<T> {
1918        match self {
1919            Ok(value) => Some(value),
1920            Err(error) => {
1921                tracing::error!("{:?}", error);
1922                None
1923            }
1924        }
1925    }
1926}