rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, ProjectId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    os::unix::prelude::OsStrExt,
  46    rc::Rc,
  47    sync::{
  48        atomic::{AtomicBool, Ordering::SeqCst},
  49        Arc,
  50    },
  51    time::Duration,
  52};
  53use time::OffsetDateTime;
  54use tokio::{
  55    sync::{Mutex, MutexGuard},
  56    time::Sleep,
  57};
  58use tower::ServiceBuilder;
  59use tracing::{info_span, instrument, Instrument};
  60
  61pub use store::{Store, Worktree};
  62
// Prometheus gauges registered with the default registry on first access.
// `register_int_gauge!(..).unwrap()` panics if registration fails (for example, if a metric with
// the same name is already registered), so such a failure surfaces on first use of the static.
// NOTE(review): none of these gauges are updated in this section; presumably they are set elsewhere.
lazy_static! {
    // Registered as "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Registered as "registered_projects".
    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
    // Registered as "active_projects".
    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
        register_int_gauge!("active_projects", "number of active projects").unwrap();
    // Registered as "shared_projects" (projects with at least one guest, per its help text).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  76
/// Type-erased handler for one message type.
///
/// It takes the server and a boxed envelope, and returns a boxed `'static` future that handles
/// the message. Handlers are built by `Server::add_message_handler` and stored in
/// `Server::handlers`, keyed by the `TypeId` of the payload type.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  79
/// One-shot responder passed to handlers registered with `Server::add_request_handler`.
///
/// `responded` is shared with the dispatcher. After the handler's future completes, the
/// dispatcher checks it to see whether a reply was sent.
struct Response<R> {
    /// Server whose peer sends the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `Response::send`.
    responded: Arc<AtomicBool>,
}
  85
  86impl<R: RequestMessage> Response<R> {
  87    fn send(self, payload: R::Response) -> Result<()> {
  88        self.responded.store(true, SeqCst);
  89        self.server.peer.respond(self.receipt, payload)?;
  90        Ok(())
  91    }
  92}
  93
/// Collaboration server state shared by every connection and message handler.
pub struct Server {
    /// RPC peer used to talk to connected clients.
    peer: Arc<Peer>,
    /// In-memory collaboration state (connections, rooms, projects) behind an async mutex.
    pub(crate) store: Mutex<Store>,
    /// Application-wide state; this file reads `db` and `invite_link_prefix` from it.
    app_state: Arc<AppState>,
    /// Registered message handlers, keyed by the `TypeId` of the payload type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When set, `handle_connection` sends a `()` after each incoming message has been handled.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
 101
/// Runtime abstraction the server uses to spawn background tasks and create timers.
pub trait Executor: Send + Clone {
    /// Future returned by `sleep`.
    type Sleep: Send + Future;
    /// Spawns `future` to run to completion without returning a handle to it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Should return a future that completes once `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 107
/// `Executor` used outside tests.
///
/// NOTE(review): presumably backed by tokio (the file imports `tokio::time::Sleep`); confirm
/// against its `Executor` impl.
#[derive(Clone)]
pub struct RealExecutor;
 110
// NOTE(review): the meanings below are inferred from the names; confirm at the use sites
// (presumably the channel-message handlers).
/// Number of channel messages returned per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on the length of a channel message (units — bytes vs. chars — to be confirmed).
const MAX_MESSAGE_LEN: usize = 1024;
 113
/// Lock guard over the server's `Store`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. A future that holds a guard across
/// an `.await` is therefore not `Send`, which keeps the store lock from being held over
/// suspension points in spawned or `Send`-bounded handler futures.
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 118
/// Serializable view of the server: its peer plus the locked store.
///
/// The snapshot owns a `StoreGuard`, so the store stays locked until the snapshot is dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `Store` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 125
 126pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 127where
 128    S: Serializer,
 129    T: Deref<Target = U>,
 130    U: Serialize,
 131{
 132    Serialize::serialize(value.deref(), serializer)
 133}
 134
 135impl Server {
    /// Creates the server and registers a handler for every message type it understands.
    ///
    /// `notifications`, when provided, is sent a `()` after each incoming message has been
    /// handled (see `handle_connection`).
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type
    /// (see `add_message_handler`).
    pub fn new(
        app_state: Arc<AppState>,
        notifications: Option<mpsc::UnboundedSender<()>>,
    ) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
            notifications,
        };

        // Request handlers must reply through their `Response`; message handlers reply to nothing.
        server
            .add_request_handler(Server::ping)
            .add_request_handler(Server::create_room)
            .add_request_handler(Server::join_room)
            .add_message_handler(Server::leave_room)
            .add_request_handler(Server::call)
            .add_request_handler(Server::cancel_call)
            .add_message_handler(Server::decline_call)
            .add_request_handler(Server::update_participant_location)
            .add_request_handler(Server::share_project)
            .add_message_handler(Server::unshare_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_message_handler(Server::update_project)
            .add_message_handler(Server::register_project_activity)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::update_worktree_extensions)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_message_handler(Server::update_diagnostic_summary)
            // Project requests relayed by `forward_project_request` (presumably to the host).
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            .add_message_handler(Server::create_buffer_for_peer)
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            .add_request_handler(Server::get_channels)
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            .add_request_handler(Server::join_channel)
            .add_message_handler(Server::leave_channel)
            .add_request_handler(Server::send_channel_message)
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_message_handler(Server::update_followers)
            .add_request_handler(Server::get_channel_messages)
            .add_message_handler(Server::update_diff_base)
            .add_request_handler(Server::get_private_user_info);

        Arc::new(server)
    }
 216
 217    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 218    where
 219        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 220        Fut: 'static + Send + Future<Output = Result<()>>,
 221        M: EnvelopedMessage,
 222    {
 223        let prev_handler = self.handlers.insert(
 224            TypeId::of::<M>(),
 225            Box::new(move |server, envelope| {
 226                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 227                let span = info_span!(
 228                    "handle message",
 229                    payload_type = envelope.payload_type_name()
 230                );
 231                span.in_scope(|| {
 232                    tracing::info!(
 233                        payload_type = envelope.payload_type_name(),
 234                        "message received"
 235                    );
 236                });
 237                let future = (handler)(server, *envelope);
 238                async move {
 239                    if let Err(error) = future.await {
 240                        tracing::error!(%error, "error handling message");
 241                    }
 242                }
 243                .instrument(span)
 244                .boxed()
 245            }),
 246        );
 247        if prev_handler.is_some() {
 248            panic!("registered a handler for the same message twice");
 249        }
 250        self
 251    }
 252
 253    /// Handle a request while holding a lock to the store. This is useful when we're registering
 254    /// a connection but we want to respond on the connection before anybody else can send on it.
 255    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 256    where
 257        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 258        Fut: Send + Future<Output = Result<()>>,
 259        M: RequestMessage,
 260    {
 261        let handler = Arc::new(handler);
 262        self.add_message_handler(move |server, envelope| {
 263            let receipt = envelope.receipt();
 264            let handler = handler.clone();
 265            async move {
 266                let responded = Arc::new(AtomicBool::default());
 267                let response = Response {
 268                    server: server.clone(),
 269                    responded: responded.clone(),
 270                    receipt: envelope.receipt(),
 271                };
 272                match (handler)(server.clone(), envelope, response).await {
 273                    Ok(()) => {
 274                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 275                            Ok(())
 276                        } else {
 277                            Err(anyhow!("handler did not send a response"))?
 278                        }
 279                    }
 280                    Err(error) => {
 281                        server.peer.respond_with_error(
 282                            receipt,
 283                            proto::Error {
 284                                message: error.to_string(),
 285                            },
 286                        )?;
 287                        Err(error)
 288                    }
 289                }
 290            }
 291        })
 292    }
 293
 294    /// Start a long lived task that records which users are active in which projects.
 295    pub fn start_recording_project_activity<E: 'static + Executor>(
 296        self: &Arc<Self>,
 297        interval: Duration,
 298        executor: E,
 299    ) {
 300        executor.spawn_detached({
 301            let this = Arc::downgrade(self);
 302            let executor = executor.clone();
 303            async move {
 304                let mut period_start = OffsetDateTime::now_utc();
 305                let mut active_projects = Vec::<(UserId, ProjectId)>::new();
 306                loop {
 307                    let sleep = executor.sleep(interval);
 308                    sleep.await;
 309                    let this = if let Some(this) = this.upgrade() {
 310                        this
 311                    } else {
 312                        break;
 313                    };
 314
 315                    active_projects.clear();
 316                    active_projects.extend(this.store().await.projects().flat_map(
 317                        |(project_id, project)| {
 318                            project.guests.values().chain([&project.host]).filter_map(
 319                                |collaborator| {
 320                                    if !collaborator.admin
 321                                        && collaborator
 322                                            .last_activity
 323                                            .map_or(false, |activity| activity > period_start)
 324                                    {
 325                                        Some((collaborator.user_id, *project_id))
 326                                    } else {
 327                                        None
 328                                    }
 329                                },
 330                            )
 331                        },
 332                    ));
 333
 334                    let period_end = OffsetDateTime::now_utc();
 335                    this.app_state
 336                        .db
 337                        .record_user_activity(period_start..period_end, &active_projects)
 338                        .await
 339                        .trace_err();
 340                    period_start = period_end;
 341                }
 342            }
 343        });
 344    }
 345
    /// Drives one client connection from registration until it closes.
    ///
    /// The steps are:
    /// 1. Registers `connection` with the peer and reports the new `ConnectionId` through
    ///    `send_connection_id` (if given).
    /// 2. Sends the client its initial state: `ShowContacts` on first connection, any pending
    ///    incoming call, the contact list, and invite info.
    /// 3. Dispatches incoming messages to the registered handlers until the I/O future finishes
    ///    or the incoming stream ends.
    /// 4. Signs the connection out. Sign-out errors are logged, not returned.
    ///
    /// # Errors
    ///
    /// Returns an error if any setup step before the message loop fails.
    ///
    /// NOTE(review): those early `?` returns skip the `sign_out` call at the end. As a result the
    /// peer connection is not explicitly disconnected, and if `store.add_connection` already ran,
    /// the store entry is never removed. Consider signing out on setup failure as well.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The peer is given a timer factory built on the executor's `sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");

            if let Some(send_connection_id) = send_connection_id.as_mut() {
                // Best effort: the send result is deliberately ignored.
                let _ = send_connection_id.send(connection_id).await;
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch contacts and the invite code concurrently.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Register the connection and send the initial state while holding the store lock.
            {
                let mut store = this.store().await;
                // `add_connection` may hand back a pending incoming call, which is delivered now.
                let incoming_call = store.add_connection(connection_id, user_id, user.admin);
                if let Some(incoming_call) = incoming_call {
                    this.peer.send(connection_id, incoming_call)?;
                }

                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A sends a request to client B while
            // client B sends a request to client A. If each client stopped processing further
            // messages until its own request completed, neither would get a chance to respond to
            // the other's request, and the two would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: I/O completion first, then in-flight foreground handlers,
                // then newly arrived messages. Background messages are spawned detached on the
                // executor instead of being polled here.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let notifications = this.notifications.clone();
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), message);

                                // Leave the span here; the future re-enters it via `.instrument(span)`.
                                drop(span_enter);
                                let handle_message = async move {
                                    handle_message.await;
                                    if let Some(mut notifications) = notifications {
                                        let _ = notifications.send(()).await;
                                    }
                                }.instrument(span);

                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 471
 472    #[instrument(skip(self), err)]
 473    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 474        self.peer.disconnect(connection_id);
 475
 476        let mut projects_to_unshare = Vec::new();
 477        let mut contacts_to_update = HashSet::default();
 478        {
 479            let mut store = self.store().await;
 480            let removed_connection = store.remove_connection(connection_id)?;
 481
 482            for project in removed_connection.hosted_projects {
 483                projects_to_unshare.push(project.id);
 484                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 485                    self.peer.send(
 486                        conn_id,
 487                        proto::UnshareProject {
 488                            project_id: project.id.to_proto(),
 489                        },
 490                    )
 491                });
 492            }
 493
 494            for project in removed_connection.guest_projects {
 495                broadcast(connection_id, project.connection_ids, |conn_id| {
 496                    self.peer.send(
 497                        conn_id,
 498                        proto::RemoveProjectCollaborator {
 499                            project_id: project.id.to_proto(),
 500                            peer_id: connection_id.0,
 501                        },
 502                    )
 503                });
 504            }
 505
 506            for connection_id in removed_connection.canceled_call_connection_ids {
 507                self.peer
 508                    .send(connection_id, proto::CallCanceled {})
 509                    .trace_err();
 510                contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
 511            }
 512
 513            if let Some(room) = removed_connection
 514                .room_id
 515                .and_then(|room_id| store.room(room_id))
 516            {
 517                self.room_updated(room);
 518            }
 519
 520            contacts_to_update.insert(removed_connection.user_id);
 521        };
 522
 523        for user_id in contacts_to_update {
 524            self.update_user_contacts(user_id).await.trace_err();
 525        }
 526
 527        for project_id in projects_to_unshare {
 528            self.app_state
 529                .db
 530                .unregister_project(project_id)
 531                .await
 532                .trace_err();
 533        }
 534
 535        Ok(())
 536    }
 537
 538    pub async fn invite_code_redeemed(
 539        self: &Arc<Self>,
 540        inviter_id: UserId,
 541        invitee_id: UserId,
 542    ) -> Result<()> {
 543        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 544            if let Some(code) = &user.invite_code {
 545                let store = self.store().await;
 546                let invitee_contact = store.contact_for_user(invitee_id, true);
 547                for connection_id in store.connection_ids_for_user(inviter_id) {
 548                    self.peer.send(
 549                        connection_id,
 550                        proto::UpdateContacts {
 551                            contacts: vec![invitee_contact.clone()],
 552                            ..Default::default()
 553                        },
 554                    )?;
 555                    self.peer.send(
 556                        connection_id,
 557                        proto::UpdateInviteInfo {
 558                            url: format!("{}{}", self.app_state.invite_link_prefix, &code),
 559                            count: user.invite_count as u32,
 560                        },
 561                    )?;
 562                }
 563            }
 564        }
 565        Ok(())
 566    }
 567
 568    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 569        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 570            if let Some(invite_code) = &user.invite_code {
 571                let store = self.store().await;
 572                for connection_id in store.connection_ids_for_user(user_id) {
 573                    self.peer.send(
 574                        connection_id,
 575                        proto::UpdateInviteInfo {
 576                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 577                            count: user.invite_count as u32,
 578                        },
 579                    )?;
 580                }
 581            }
 582        }
 583        Ok(())
 584    }
 585
 586    async fn ping(
 587        self: Arc<Server>,
 588        _: TypedEnvelope<proto::Ping>,
 589        response: Response<proto::Ping>,
 590    ) -> Result<()> {
 591        response.send(proto::Ack {})?;
 592        Ok(())
 593    }
 594
 595    async fn create_room(
 596        self: Arc<Server>,
 597        request: TypedEnvelope<proto::CreateRoom>,
 598        response: Response<proto::CreateRoom>,
 599    ) -> Result<()> {
 600        let user_id;
 601        let room_id;
 602        {
 603            let mut store = self.store().await;
 604            user_id = store.user_id_for_connection(request.sender_id)?;
 605            room_id = store.create_room(request.sender_id)?;
 606        }
 607        response.send(proto::CreateRoomResponse { id: room_id })?;
 608        self.update_user_contacts(user_id).await?;
 609        Ok(())
 610    }
 611
 612    async fn join_room(
 613        self: Arc<Server>,
 614        request: TypedEnvelope<proto::JoinRoom>,
 615        response: Response<proto::JoinRoom>,
 616    ) -> Result<()> {
 617        let user_id;
 618        {
 619            let mut store = self.store().await;
 620            user_id = store.user_id_for_connection(request.sender_id)?;
 621            let (room, recipient_connection_ids) =
 622                store.join_room(request.payload.id, request.sender_id)?;
 623            for recipient_id in recipient_connection_ids {
 624                self.peer
 625                    .send(recipient_id, proto::CallCanceled {})
 626                    .trace_err();
 627            }
 628            response.send(proto::JoinRoomResponse {
 629                room: Some(room.clone()),
 630            })?;
 631            self.room_updated(room);
 632        }
 633        self.update_user_contacts(user_id).await?;
 634        Ok(())
 635    }
 636
 637    async fn leave_room(self: Arc<Server>, message: TypedEnvelope<proto::LeaveRoom>) -> Result<()> {
 638        let mut contacts_to_update = HashSet::default();
 639        {
 640            let mut store = self.store().await;
 641            let user_id = store.user_id_for_connection(message.sender_id)?;
 642            let left_room = store.leave_room(message.payload.id, message.sender_id)?;
 643            contacts_to_update.insert(user_id);
 644
 645            for project in left_room.unshared_projects {
 646                for connection_id in project.connection_ids() {
 647                    self.peer.send(
 648                        connection_id,
 649                        proto::UnshareProject {
 650                            project_id: project.id.to_proto(),
 651                        },
 652                    )?;
 653                }
 654            }
 655
 656            for project in left_room.left_projects {
 657                if project.remove_collaborator {
 658                    for connection_id in project.connection_ids {
 659                        self.peer.send(
 660                            connection_id,
 661                            proto::RemoveProjectCollaborator {
 662                                project_id: project.id.to_proto(),
 663                                peer_id: message.sender_id.0,
 664                            },
 665                        )?;
 666                    }
 667
 668                    self.peer.send(
 669                        message.sender_id,
 670                        proto::UnshareProject {
 671                            project_id: project.id.to_proto(),
 672                        },
 673                    )?;
 674                }
 675            }
 676
 677            if let Some(room) = left_room.room {
 678                self.room_updated(room);
 679            }
 680
 681            for connection_id in left_room.canceled_call_connection_ids {
 682                self.peer
 683                    .send(connection_id, proto::CallCanceled {})
 684                    .trace_err();
 685                contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
 686            }
 687        }
 688
 689        for user_id in contacts_to_update {
 690            self.update_user_contacts(user_id).await?;
 691        }
 692
 693        Ok(())
 694    }
 695
    /// Rings `recipient_user_id` on behalf of the user behind the sending connection.
    ///
    /// The caller must already have the recipient as a contact (checked via
    /// `db.has_contact`). The store records the call and returns the room, the recipient
    /// connections to ring, and the `incoming_call` message; that message is sent as a
    /// request to every recipient connection concurrently. The call request is
    /// acknowledged as soon as any one of those requests succeeds. If all of them fail,
    /// the store is told the call failed, the room and the recipient's contacts are
    /// refreshed, and an error is returned to the caller.
    async fn call(
        self: Arc<Server>,
        request: TypedEnvelope<proto::Call>,
        response: Response<proto::Call>,
    ) -> Result<()> {
        let caller_user_id = self
            .store()
            .await
            .user_id_for_connection(request.sender_id)?;
        let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id);
        let initial_project_id = request
            .payload
            .initial_project_id
            .map(ProjectId::from_proto);
        if !self
            .app_state
            .db
            .has_contact(caller_user_id, recipient_user_id)
            .await?
        {
            return Err(anyhow!("cannot call a user who isn't a contact"))?;
        }

        let room_id = request.payload.room_id;
        // The store guard lives only inside this block: `update_user_contacts` below
        // locks the store again, so the guard must be gone before that await.
        let mut calls = {
            let mut store = self.store().await;
            let (room, recipient_connection_ids, incoming_call) = store.call(
                room_id,
                recipient_user_id,
                initial_project_id,
                request.sender_id,
            )?;
            self.room_updated(room);
            // One pending request per recipient connection, polled as they complete.
            recipient_connection_ids
                .into_iter()
                .map(|recipient_connection_id| {
                    self.peer
                        .request(recipient_connection_id, incoming_call.clone())
                })
                .collect::<FuturesUnordered<_>>()
        };
        self.update_user_contacts(recipient_user_id).await?;

        // The first successful answer from any recipient connection acknowledges the
        // call; failed attempts are handed to `trace_err` and the race continues.
        while let Some(call_response) = calls.next().await {
            match call_response.as_ref() {
                Ok(_) => {
                    response.send(proto::Ack {})?;
                    return Ok(());
                }
                Err(_) => {
                    call_response.trace_err();
                }
            }
        }

        // No recipient connection answered successfully: roll the call back in the
        // store and publish the resulting room state (again releasing the guard
        // before `update_user_contacts` re-locks the store).
        {
            let mut store = self.store().await;
            let room = store.call_failed(room_id, recipient_user_id)?;
            self.room_updated(&room);
        }
        self.update_user_contacts(recipient_user_id).await?;

        Err(anyhow!("failed to ring call recipient"))?
    }
 760
 761    async fn cancel_call(
 762        self: Arc<Server>,
 763        request: TypedEnvelope<proto::CancelCall>,
 764        response: Response<proto::CancelCall>,
 765    ) -> Result<()> {
 766        let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id);
 767        {
 768            let mut store = self.store().await;
 769            let (room, recipient_connection_ids) = store.cancel_call(
 770                request.payload.room_id,
 771                recipient_user_id,
 772                request.sender_id,
 773            )?;
 774            for recipient_id in recipient_connection_ids {
 775                self.peer
 776                    .send(recipient_id, proto::CallCanceled {})
 777                    .trace_err();
 778            }
 779            self.room_updated(room);
 780            response.send(proto::Ack {})?;
 781        }
 782        self.update_user_contacts(recipient_user_id).await?;
 783        Ok(())
 784    }
 785
 786    async fn decline_call(
 787        self: Arc<Server>,
 788        message: TypedEnvelope<proto::DeclineCall>,
 789    ) -> Result<()> {
 790        let recipient_user_id;
 791        {
 792            let mut store = self.store().await;
 793            recipient_user_id = store.user_id_for_connection(message.sender_id)?;
 794            let (room, recipient_connection_ids) =
 795                store.decline_call(message.payload.room_id, message.sender_id)?;
 796            for recipient_id in recipient_connection_ids {
 797                self.peer
 798                    .send(recipient_id, proto::CallCanceled {})
 799                    .trace_err();
 800            }
 801            self.room_updated(room);
 802        }
 803        self.update_user_contacts(recipient_user_id).await?;
 804        Ok(())
 805    }
 806
 807    async fn update_participant_location(
 808        self: Arc<Server>,
 809        request: TypedEnvelope<proto::UpdateParticipantLocation>,
 810        response: Response<proto::UpdateParticipantLocation>,
 811    ) -> Result<()> {
 812        let room_id = request.payload.room_id;
 813        let location = request
 814            .payload
 815            .location
 816            .ok_or_else(|| anyhow!("invalid location"))?;
 817        let mut store = self.store().await;
 818        let room = store.update_participant_location(room_id, location, request.sender_id)?;
 819        self.room_updated(room);
 820        response.send(proto::Ack {})?;
 821        Ok(())
 822    }
 823
 824    fn room_updated(&self, room: &proto::Room) {
 825        for participant in &room.participants {
 826            self.peer
 827                .send(
 828                    ConnectionId(participant.peer_id),
 829                    proto::RoomUpdated {
 830                        room: Some(room.clone()),
 831                    },
 832                )
 833                .trace_err();
 834        }
 835    }
 836
 837    async fn share_project(
 838        self: Arc<Server>,
 839        request: TypedEnvelope<proto::ShareProject>,
 840        response: Response<proto::ShareProject>,
 841    ) -> Result<()> {
 842        let user_id = self
 843            .store()
 844            .await
 845            .user_id_for_connection(request.sender_id)?;
 846        let project_id = self.app_state.db.register_project(user_id).await?;
 847        let mut store = self.store().await;
 848        let room = store.share_project(
 849            request.payload.room_id,
 850            project_id,
 851            request.payload.worktrees,
 852            request.sender_id,
 853        )?;
 854        response.send(proto::ShareProjectResponse {
 855            project_id: project_id.to_proto(),
 856        })?;
 857        self.room_updated(room);
 858
 859        Ok(())
 860    }
 861
 862    async fn unshare_project(
 863        self: Arc<Server>,
 864        message: TypedEnvelope<proto::UnshareProject>,
 865    ) -> Result<()> {
 866        let project_id = ProjectId::from_proto(message.payload.project_id);
 867        let mut store = self.store().await;
 868        let (room, project) = store.unshare_project(project_id, message.sender_id)?;
 869        broadcast(
 870            message.sender_id,
 871            project.guest_connection_ids(),
 872            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 873        );
 874        self.room_updated(room);
 875
 876        Ok(())
 877    }
 878
 879    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 880        let contacts = self.app_state.db.get_contacts(user_id).await?;
 881        let store = self.store().await;
 882        let updated_contact = store.contact_for_user(user_id, false);
 883        for contact in contacts {
 884            if let db::Contact::Accepted {
 885                user_id: contact_user_id,
 886                ..
 887            } = contact
 888            {
 889                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 890                    self.peer
 891                        .send(
 892                            contact_conn_id,
 893                            proto::UpdateContacts {
 894                                contacts: vec![updated_contact.clone()],
 895                                remove_contacts: Default::default(),
 896                                incoming_requests: Default::default(),
 897                                remove_incoming_requests: Default::default(),
 898                                outgoing_requests: Default::default(),
 899                                remove_outgoing_requests: Default::default(),
 900                            },
 901                        )
 902                        .trace_err();
 903                }
 904            }
 905        }
 906        Ok(())
 907    }
 908
 909    async fn join_project(
 910        self: Arc<Server>,
 911        request: TypedEnvelope<proto::JoinProject>,
 912        response: Response<proto::JoinProject>,
 913    ) -> Result<()> {
 914        let project_id = ProjectId::from_proto(request.payload.project_id);
 915
 916        let host_user_id;
 917        let guest_user_id;
 918        let host_connection_id;
 919        {
 920            let state = self.store().await;
 921            let project = state.project(project_id)?;
 922            host_user_id = project.host.user_id;
 923            host_connection_id = project.host_connection_id;
 924            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 925        };
 926
 927        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 928
 929        let mut store = self.store().await;
 930        let (project, replica_id) = store.join_project(request.sender_id, project_id)?;
 931        let peer_count = project.guests.len();
 932        let mut collaborators = Vec::with_capacity(peer_count);
 933        collaborators.push(proto::Collaborator {
 934            peer_id: project.host_connection_id.0,
 935            replica_id: 0,
 936            user_id: project.host.user_id.to_proto(),
 937        });
 938        let worktrees = project
 939            .worktrees
 940            .iter()
 941            .map(|(id, worktree)| proto::WorktreeMetadata {
 942                id: *id,
 943                root_name: worktree.root_name.clone(),
 944                visible: worktree.visible,
 945                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
 946            })
 947            .collect::<Vec<_>>();
 948
 949        // Add all guests other than the requesting user's own connections as collaborators
 950        for (guest_conn_id, guest) in &project.guests {
 951            if request.sender_id != *guest_conn_id {
 952                collaborators.push(proto::Collaborator {
 953                    peer_id: guest_conn_id.0,
 954                    replica_id: guest.replica_id as u32,
 955                    user_id: guest.user_id.to_proto(),
 956                });
 957            }
 958        }
 959
 960        for conn_id in project.connection_ids() {
 961            if conn_id != request.sender_id {
 962                self.peer.send(
 963                    conn_id,
 964                    proto::AddProjectCollaborator {
 965                        project_id: project_id.to_proto(),
 966                        collaborator: Some(proto::Collaborator {
 967                            peer_id: request.sender_id.0,
 968                            replica_id: replica_id as u32,
 969                            user_id: guest_user_id.to_proto(),
 970                        }),
 971                    },
 972                )?;
 973            }
 974        }
 975
 976        // First, we send the metadata associated with each worktree.
 977        response.send(proto::JoinProjectResponse {
 978            worktrees: worktrees.clone(),
 979            replica_id: replica_id as u32,
 980            collaborators: collaborators.clone(),
 981            language_servers: project.language_servers.clone(),
 982        })?;
 983
 984        for (worktree_id, worktree) in &project.worktrees {
 985            #[cfg(any(test, feature = "test-support"))]
 986            const MAX_CHUNK_SIZE: usize = 2;
 987            #[cfg(not(any(test, feature = "test-support")))]
 988            const MAX_CHUNK_SIZE: usize = 256;
 989
 990            // Stream this worktree's entries.
 991            let message = proto::UpdateWorktree {
 992                project_id: project_id.to_proto(),
 993                worktree_id: *worktree_id,
 994                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
 995                root_name: worktree.root_name.clone(),
 996                updated_entries: worktree.entries.values().cloned().collect(),
 997                removed_entries: Default::default(),
 998                scan_id: worktree.scan_id,
 999                is_last_update: worktree.is_complete,
1000            };
1001            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1002                self.peer.send(request.sender_id, update.clone())?;
1003            }
1004
1005            // Stream this worktree's diagnostics.
1006            for summary in worktree.diagnostic_summaries.values() {
1007                self.peer.send(
1008                    request.sender_id,
1009                    proto::UpdateDiagnosticSummary {
1010                        project_id: project_id.to_proto(),
1011                        worktree_id: *worktree_id,
1012                        summary: Some(summary.clone()),
1013                    },
1014                )?;
1015            }
1016        }
1017
1018        for language_server in &project.language_servers {
1019            self.peer.send(
1020                request.sender_id,
1021                proto::UpdateLanguageServer {
1022                    project_id: project_id.to_proto(),
1023                    language_server_id: language_server.id,
1024                    variant: Some(
1025                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1026                            proto::LspDiskBasedDiagnosticsUpdated {},
1027                        ),
1028                    ),
1029                },
1030            )?;
1031        }
1032
1033        Ok(())
1034    }
1035
1036    async fn leave_project(
1037        self: Arc<Server>,
1038        request: TypedEnvelope<proto::LeaveProject>,
1039    ) -> Result<()> {
1040        let sender_id = request.sender_id;
1041        let project_id = ProjectId::from_proto(request.payload.project_id);
1042        let project;
1043        {
1044            let mut store = self.store().await;
1045            project = store.leave_project(project_id, sender_id)?;
1046            tracing::info!(
1047                %project_id,
1048                host_user_id = %project.host_user_id,
1049                host_connection_id = %project.host_connection_id,
1050                "leave project"
1051            );
1052
1053            if project.remove_collaborator {
1054                broadcast(sender_id, project.connection_ids, |conn_id| {
1055                    self.peer.send(
1056                        conn_id,
1057                        proto::RemoveProjectCollaborator {
1058                            project_id: project_id.to_proto(),
1059                            peer_id: sender_id.0,
1060                        },
1061                    )
1062                });
1063            }
1064        }
1065
1066        Ok(())
1067    }
1068
1069    async fn update_project(
1070        self: Arc<Server>,
1071        request: TypedEnvelope<proto::UpdateProject>,
1072    ) -> Result<()> {
1073        let project_id = ProjectId::from_proto(request.payload.project_id);
1074        {
1075            let mut state = self.store().await;
1076            let guest_connection_ids = state
1077                .read_project(project_id, request.sender_id)?
1078                .guest_connection_ids();
1079            let room =
1080                state.update_project(project_id, &request.payload.worktrees, request.sender_id)?;
1081            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
1082                self.peer
1083                    .forward_send(request.sender_id, connection_id, request.payload.clone())
1084            });
1085            self.room_updated(room);
1086        };
1087
1088        Ok(())
1089    }
1090
1091    async fn register_project_activity(
1092        self: Arc<Server>,
1093        request: TypedEnvelope<proto::RegisterProjectActivity>,
1094    ) -> Result<()> {
1095        self.store().await.register_project_activity(
1096            ProjectId::from_proto(request.payload.project_id),
1097            request.sender_id,
1098        )?;
1099        Ok(())
1100    }
1101
1102    async fn update_worktree(
1103        self: Arc<Server>,
1104        request: TypedEnvelope<proto::UpdateWorktree>,
1105        response: Response<proto::UpdateWorktree>,
1106    ) -> Result<()> {
1107        let project_id = ProjectId::from_proto(request.payload.project_id);
1108        let worktree_id = request.payload.worktree_id;
1109        let connection_ids = self.store().await.update_worktree(
1110            request.sender_id,
1111            project_id,
1112            worktree_id,
1113            &request.payload.root_name,
1114            &request.payload.removed_entries,
1115            &request.payload.updated_entries,
1116            request.payload.scan_id,
1117            request.payload.is_last_update,
1118        )?;
1119
1120        broadcast(request.sender_id, connection_ids, |connection_id| {
1121            self.peer
1122                .forward_send(request.sender_id, connection_id, request.payload.clone())
1123        });
1124        response.send(proto::Ack {})?;
1125        Ok(())
1126    }
1127
1128    async fn update_worktree_extensions(
1129        self: Arc<Server>,
1130        request: TypedEnvelope<proto::UpdateWorktreeExtensions>,
1131    ) -> Result<()> {
1132        let project_id = ProjectId::from_proto(request.payload.project_id);
1133        let worktree_id = request.payload.worktree_id;
1134        let extensions = request
1135            .payload
1136            .extensions
1137            .into_iter()
1138            .zip(request.payload.counts)
1139            .collect();
1140        self.app_state
1141            .db
1142            .update_worktree_extensions(project_id, worktree_id, extensions)
1143            .await?;
1144        Ok(())
1145    }
1146
1147    async fn update_diagnostic_summary(
1148        self: Arc<Server>,
1149        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
1150    ) -> Result<()> {
1151        let summary = request
1152            .payload
1153            .summary
1154            .clone()
1155            .ok_or_else(|| anyhow!("invalid summary"))?;
1156        let receiver_ids = self.store().await.update_diagnostic_summary(
1157            ProjectId::from_proto(request.payload.project_id),
1158            request.payload.worktree_id,
1159            request.sender_id,
1160            summary,
1161        )?;
1162
1163        broadcast(request.sender_id, receiver_ids, |connection_id| {
1164            self.peer
1165                .forward_send(request.sender_id, connection_id, request.payload.clone())
1166        });
1167        Ok(())
1168    }
1169
1170    async fn start_language_server(
1171        self: Arc<Server>,
1172        request: TypedEnvelope<proto::StartLanguageServer>,
1173    ) -> Result<()> {
1174        let receiver_ids = self.store().await.start_language_server(
1175            ProjectId::from_proto(request.payload.project_id),
1176            request.sender_id,
1177            request
1178                .payload
1179                .server
1180                .clone()
1181                .ok_or_else(|| anyhow!("invalid language server"))?,
1182        )?;
1183        broadcast(request.sender_id, receiver_ids, |connection_id| {
1184            self.peer
1185                .forward_send(request.sender_id, connection_id, request.payload.clone())
1186        });
1187        Ok(())
1188    }
1189
1190    async fn update_language_server(
1191        self: Arc<Server>,
1192        request: TypedEnvelope<proto::UpdateLanguageServer>,
1193    ) -> Result<()> {
1194        let receiver_ids = self.store().await.project_connection_ids(
1195            ProjectId::from_proto(request.payload.project_id),
1196            request.sender_id,
1197        )?;
1198        broadcast(request.sender_id, receiver_ids, |connection_id| {
1199            self.peer
1200                .forward_send(request.sender_id, connection_id, request.payload.clone())
1201        });
1202        Ok(())
1203    }
1204
1205    async fn forward_project_request<T>(
1206        self: Arc<Server>,
1207        request: TypedEnvelope<T>,
1208        response: Response<T>,
1209    ) -> Result<()>
1210    where
1211        T: EntityMessage + RequestMessage,
1212    {
1213        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1214        let host_connection_id = self
1215            .store()
1216            .await
1217            .read_project(project_id, request.sender_id)?
1218            .host_connection_id;
1219        let payload = self
1220            .peer
1221            .forward_request(request.sender_id, host_connection_id, request.payload)
1222            .await?;
1223
1224        // Ensure project still exists by the time we get the response from the host.
1225        self.store()
1226            .await
1227            .read_project(project_id, request.sender_id)?;
1228
1229        response.send(payload)?;
1230        Ok(())
1231    }
1232
1233    async fn save_buffer(
1234        self: Arc<Server>,
1235        request: TypedEnvelope<proto::SaveBuffer>,
1236        response: Response<proto::SaveBuffer>,
1237    ) -> Result<()> {
1238        let project_id = ProjectId::from_proto(request.payload.project_id);
1239        let host = self
1240            .store()
1241            .await
1242            .read_project(project_id, request.sender_id)?
1243            .host_connection_id;
1244        let response_payload = self
1245            .peer
1246            .forward_request(request.sender_id, host, request.payload.clone())
1247            .await?;
1248
1249        let mut guests = self
1250            .store()
1251            .await
1252            .read_project(project_id, request.sender_id)?
1253            .connection_ids();
1254        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1255        broadcast(host, guests, |conn_id| {
1256            self.peer
1257                .forward_send(host, conn_id, response_payload.clone())
1258        });
1259        response.send(response_payload)?;
1260        Ok(())
1261    }
1262
1263    async fn create_buffer_for_peer(
1264        self: Arc<Server>,
1265        request: TypedEnvelope<proto::CreateBufferForPeer>,
1266    ) -> Result<()> {
1267        self.peer.forward_send(
1268            request.sender_id,
1269            ConnectionId(request.payload.peer_id),
1270            request.payload,
1271        )?;
1272        Ok(())
1273    }
1274
1275    async fn update_buffer(
1276        self: Arc<Server>,
1277        request: TypedEnvelope<proto::UpdateBuffer>,
1278        response: Response<proto::UpdateBuffer>,
1279    ) -> Result<()> {
1280        let project_id = ProjectId::from_proto(request.payload.project_id);
1281        let receiver_ids = {
1282            let mut store = self.store().await;
1283            store.register_project_activity(project_id, request.sender_id)?;
1284            store.project_connection_ids(project_id, request.sender_id)?
1285        };
1286
1287        broadcast(request.sender_id, receiver_ids, |connection_id| {
1288            self.peer
1289                .forward_send(request.sender_id, connection_id, request.payload.clone())
1290        });
1291        response.send(proto::Ack {})?;
1292        Ok(())
1293    }
1294
1295    async fn update_buffer_file(
1296        self: Arc<Server>,
1297        request: TypedEnvelope<proto::UpdateBufferFile>,
1298    ) -> Result<()> {
1299        let receiver_ids = self.store().await.project_connection_ids(
1300            ProjectId::from_proto(request.payload.project_id),
1301            request.sender_id,
1302        )?;
1303        broadcast(request.sender_id, receiver_ids, |connection_id| {
1304            self.peer
1305                .forward_send(request.sender_id, connection_id, request.payload.clone())
1306        });
1307        Ok(())
1308    }
1309
1310    async fn buffer_reloaded(
1311        self: Arc<Server>,
1312        request: TypedEnvelope<proto::BufferReloaded>,
1313    ) -> Result<()> {
1314        let receiver_ids = self.store().await.project_connection_ids(
1315            ProjectId::from_proto(request.payload.project_id),
1316            request.sender_id,
1317        )?;
1318        broadcast(request.sender_id, receiver_ids, |connection_id| {
1319            self.peer
1320                .forward_send(request.sender_id, connection_id, request.payload.clone())
1321        });
1322        Ok(())
1323    }
1324
1325    async fn buffer_saved(
1326        self: Arc<Server>,
1327        request: TypedEnvelope<proto::BufferSaved>,
1328    ) -> Result<()> {
1329        let receiver_ids = self.store().await.project_connection_ids(
1330            ProjectId::from_proto(request.payload.project_id),
1331            request.sender_id,
1332        )?;
1333        broadcast(request.sender_id, receiver_ids, |connection_id| {
1334            self.peer
1335                .forward_send(request.sender_id, connection_id, request.payload.clone())
1336        });
1337        Ok(())
1338    }
1339
1340    async fn follow(
1341        self: Arc<Self>,
1342        request: TypedEnvelope<proto::Follow>,
1343        response: Response<proto::Follow>,
1344    ) -> Result<()> {
1345        let project_id = ProjectId::from_proto(request.payload.project_id);
1346        let leader_id = ConnectionId(request.payload.leader_id);
1347        let follower_id = request.sender_id;
1348        {
1349            let mut store = self.store().await;
1350            if !store
1351                .project_connection_ids(project_id, follower_id)?
1352                .contains(&leader_id)
1353            {
1354                Err(anyhow!("no such peer"))?;
1355            }
1356
1357            store.register_project_activity(project_id, follower_id)?;
1358        }
1359
1360        let mut response_payload = self
1361            .peer
1362            .forward_request(request.sender_id, leader_id, request.payload)
1363            .await?;
1364        response_payload
1365            .views
1366            .retain(|view| view.leader_id != Some(follower_id.0));
1367        response.send(response_payload)?;
1368        Ok(())
1369    }
1370
1371    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1372        let project_id = ProjectId::from_proto(request.payload.project_id);
1373        let leader_id = ConnectionId(request.payload.leader_id);
1374        let mut store = self.store().await;
1375        if !store
1376            .project_connection_ids(project_id, request.sender_id)?
1377            .contains(&leader_id)
1378        {
1379            Err(anyhow!("no such peer"))?;
1380        }
1381        store.register_project_activity(project_id, request.sender_id)?;
1382        self.peer
1383            .forward_send(request.sender_id, leader_id, request.payload)?;
1384        Ok(())
1385    }
1386
1387    async fn update_followers(
1388        self: Arc<Self>,
1389        request: TypedEnvelope<proto::UpdateFollowers>,
1390    ) -> Result<()> {
1391        let project_id = ProjectId::from_proto(request.payload.project_id);
1392        let mut store = self.store().await;
1393        store.register_project_activity(project_id, request.sender_id)?;
1394        let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
1395        let leader_id = request
1396            .payload
1397            .variant
1398            .as_ref()
1399            .and_then(|variant| match variant {
1400                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1401                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1402                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1403            });
1404        for follower_id in &request.payload.follower_ids {
1405            let follower_id = ConnectionId(*follower_id);
1406            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1407                self.peer
1408                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1409            }
1410        }
1411        Ok(())
1412    }
1413
1414    async fn get_channels(
1415        self: Arc<Server>,
1416        request: TypedEnvelope<proto::GetChannels>,
1417        response: Response<proto::GetChannels>,
1418    ) -> Result<()> {
1419        let user_id = self
1420            .store()
1421            .await
1422            .user_id_for_connection(request.sender_id)?;
1423        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1424        response.send(proto::GetChannelsResponse {
1425            channels: channels
1426                .into_iter()
1427                .map(|chan| proto::Channel {
1428                    id: chan.id.to_proto(),
1429                    name: chan.name,
1430                })
1431                .collect(),
1432        })?;
1433        Ok(())
1434    }
1435
1436    async fn get_users(
1437        self: Arc<Server>,
1438        request: TypedEnvelope<proto::GetUsers>,
1439        response: Response<proto::GetUsers>,
1440    ) -> Result<()> {
1441        let user_ids = request
1442            .payload
1443            .user_ids
1444            .into_iter()
1445            .map(UserId::from_proto)
1446            .collect();
1447        let users = self
1448            .app_state
1449            .db
1450            .get_users_by_ids(user_ids)
1451            .await?
1452            .into_iter()
1453            .map(|user| proto::User {
1454                id: user.id.to_proto(),
1455                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1456                github_login: user.github_login,
1457            })
1458            .collect();
1459        response.send(proto::UsersResponse { users })?;
1460        Ok(())
1461    }
1462
1463    async fn fuzzy_search_users(
1464        self: Arc<Server>,
1465        request: TypedEnvelope<proto::FuzzySearchUsers>,
1466        response: Response<proto::FuzzySearchUsers>,
1467    ) -> Result<()> {
1468        let user_id = self
1469            .store()
1470            .await
1471            .user_id_for_connection(request.sender_id)?;
1472        let query = request.payload.query;
1473        let db = &self.app_state.db;
1474        let users = match query.len() {
1475            0 => vec![],
1476            1 | 2 => db
1477                .get_user_by_github_account(&query, None)
1478                .await?
1479                .into_iter()
1480                .collect(),
1481            _ => db.fuzzy_search_users(&query, 10).await?,
1482        };
1483        let users = users
1484            .into_iter()
1485            .filter(|user| user.id != user_id)
1486            .map(|user| proto::User {
1487                id: user.id.to_proto(),
1488                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1489                github_login: user.github_login,
1490            })
1491            .collect();
1492        response.send(proto::UsersResponse { users })?;
1493        Ok(())
1494    }
1495
1496    async fn request_contact(
1497        self: Arc<Server>,
1498        request: TypedEnvelope<proto::RequestContact>,
1499        response: Response<proto::RequestContact>,
1500    ) -> Result<()> {
1501        let requester_id = self
1502            .store()
1503            .await
1504            .user_id_for_connection(request.sender_id)?;
1505        let responder_id = UserId::from_proto(request.payload.responder_id);
1506        if requester_id == responder_id {
1507            return Err(anyhow!("cannot add yourself as a contact"))?;
1508        }
1509
1510        self.app_state
1511            .db
1512            .send_contact_request(requester_id, responder_id)
1513            .await?;
1514
1515        // Update outgoing contact requests of requester
1516        let mut update = proto::UpdateContacts::default();
1517        update.outgoing_requests.push(responder_id.to_proto());
1518        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1519            self.peer.send(connection_id, update.clone())?;
1520        }
1521
1522        // Update incoming contact requests of responder
1523        let mut update = proto::UpdateContacts::default();
1524        update
1525            .incoming_requests
1526            .push(proto::IncomingContactRequest {
1527                requester_id: requester_id.to_proto(),
1528                should_notify: true,
1529            });
1530        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1531            self.peer.send(connection_id, update.clone())?;
1532        }
1533
1534        response.send(proto::Ack {})?;
1535        Ok(())
1536    }
1537
1538    async fn respond_to_contact_request(
1539        self: Arc<Server>,
1540        request: TypedEnvelope<proto::RespondToContactRequest>,
1541        response: Response<proto::RespondToContactRequest>,
1542    ) -> Result<()> {
1543        let responder_id = self
1544            .store()
1545            .await
1546            .user_id_for_connection(request.sender_id)?;
1547        let requester_id = UserId::from_proto(request.payload.requester_id);
1548        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1549            self.app_state
1550                .db
1551                .dismiss_contact_notification(responder_id, requester_id)
1552                .await?;
1553        } else {
1554            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1555            self.app_state
1556                .db
1557                .respond_to_contact_request(responder_id, requester_id, accept)
1558                .await?;
1559
1560            let store = self.store().await;
1561            // Update responder with new contact
1562            let mut update = proto::UpdateContacts::default();
1563            if accept {
1564                update
1565                    .contacts
1566                    .push(store.contact_for_user(requester_id, false));
1567            }
1568            update
1569                .remove_incoming_requests
1570                .push(requester_id.to_proto());
1571            for connection_id in store.connection_ids_for_user(responder_id) {
1572                self.peer.send(connection_id, update.clone())?;
1573            }
1574
1575            // Update requester with new contact
1576            let mut update = proto::UpdateContacts::default();
1577            if accept {
1578                update
1579                    .contacts
1580                    .push(store.contact_for_user(responder_id, true));
1581            }
1582            update
1583                .remove_outgoing_requests
1584                .push(responder_id.to_proto());
1585            for connection_id in store.connection_ids_for_user(requester_id) {
1586                self.peer.send(connection_id, update.clone())?;
1587            }
1588        }
1589
1590        response.send(proto::Ack {})?;
1591        Ok(())
1592    }
1593
1594    async fn remove_contact(
1595        self: Arc<Server>,
1596        request: TypedEnvelope<proto::RemoveContact>,
1597        response: Response<proto::RemoveContact>,
1598    ) -> Result<()> {
1599        let requester_id = self
1600            .store()
1601            .await
1602            .user_id_for_connection(request.sender_id)?;
1603        let responder_id = UserId::from_proto(request.payload.user_id);
1604        self.app_state
1605            .db
1606            .remove_contact(requester_id, responder_id)
1607            .await?;
1608
1609        // Update outgoing contact requests of requester
1610        let mut update = proto::UpdateContacts::default();
1611        update
1612            .remove_outgoing_requests
1613            .push(responder_id.to_proto());
1614        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1615            self.peer.send(connection_id, update.clone())?;
1616        }
1617
1618        // Update incoming contact requests of responder
1619        let mut update = proto::UpdateContacts::default();
1620        update
1621            .remove_incoming_requests
1622            .push(requester_id.to_proto());
1623        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1624            self.peer.send(connection_id, update.clone())?;
1625        }
1626
1627        response.send(proto::Ack {})?;
1628        Ok(())
1629    }
1630
1631    async fn join_channel(
1632        self: Arc<Self>,
1633        request: TypedEnvelope<proto::JoinChannel>,
1634        response: Response<proto::JoinChannel>,
1635    ) -> Result<()> {
1636        let user_id = self
1637            .store()
1638            .await
1639            .user_id_for_connection(request.sender_id)?;
1640        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1641        if !self
1642            .app_state
1643            .db
1644            .can_user_access_channel(user_id, channel_id)
1645            .await?
1646        {
1647            Err(anyhow!("access denied"))?;
1648        }
1649
1650        self.store()
1651            .await
1652            .join_channel(request.sender_id, channel_id);
1653        let messages = self
1654            .app_state
1655            .db
1656            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1657            .await?
1658            .into_iter()
1659            .map(|msg| proto::ChannelMessage {
1660                id: msg.id.to_proto(),
1661                body: msg.body,
1662                timestamp: msg.sent_at.unix_timestamp() as u64,
1663                sender_id: msg.sender_id.to_proto(),
1664                nonce: Some(msg.nonce.as_u128().into()),
1665            })
1666            .collect::<Vec<_>>();
1667        response.send(proto::JoinChannelResponse {
1668            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1669            messages,
1670        })?;
1671        Ok(())
1672    }
1673
1674    async fn leave_channel(
1675        self: Arc<Self>,
1676        request: TypedEnvelope<proto::LeaveChannel>,
1677    ) -> Result<()> {
1678        let user_id = self
1679            .store()
1680            .await
1681            .user_id_for_connection(request.sender_id)?;
1682        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1683        if !self
1684            .app_state
1685            .db
1686            .can_user_access_channel(user_id, channel_id)
1687            .await?
1688        {
1689            Err(anyhow!("access denied"))?;
1690        }
1691
1692        self.store()
1693            .await
1694            .leave_channel(request.sender_id, channel_id);
1695
1696        Ok(())
1697    }
1698
1699    async fn send_channel_message(
1700        self: Arc<Self>,
1701        request: TypedEnvelope<proto::SendChannelMessage>,
1702        response: Response<proto::SendChannelMessage>,
1703    ) -> Result<()> {
1704        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1705        let user_id;
1706        let connection_ids;
1707        {
1708            let state = self.store().await;
1709            user_id = state.user_id_for_connection(request.sender_id)?;
1710            connection_ids = state.channel_connection_ids(channel_id)?;
1711        }
1712
1713        // Validate the message body.
1714        let body = request.payload.body.trim().to_string();
1715        if body.len() > MAX_MESSAGE_LEN {
1716            return Err(anyhow!("message is too long"))?;
1717        }
1718        if body.is_empty() {
1719            return Err(anyhow!("message can't be blank"))?;
1720        }
1721
1722        let timestamp = OffsetDateTime::now_utc();
1723        let nonce = request
1724            .payload
1725            .nonce
1726            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1727
1728        let message_id = self
1729            .app_state
1730            .db
1731            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1732            .await?
1733            .to_proto();
1734        let message = proto::ChannelMessage {
1735            sender_id: user_id.to_proto(),
1736            id: message_id,
1737            body,
1738            timestamp: timestamp.unix_timestamp() as u64,
1739            nonce: Some(nonce),
1740        };
1741        broadcast(request.sender_id, connection_ids, |conn_id| {
1742            self.peer.send(
1743                conn_id,
1744                proto::ChannelMessageSent {
1745                    channel_id: channel_id.to_proto(),
1746                    message: Some(message.clone()),
1747                },
1748            )
1749        });
1750        response.send(proto::SendChannelMessageResponse {
1751            message: Some(message),
1752        })?;
1753        Ok(())
1754    }
1755
1756    async fn get_channel_messages(
1757        self: Arc<Self>,
1758        request: TypedEnvelope<proto::GetChannelMessages>,
1759        response: Response<proto::GetChannelMessages>,
1760    ) -> Result<()> {
1761        let user_id = self
1762            .store()
1763            .await
1764            .user_id_for_connection(request.sender_id)?;
1765        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1766        if !self
1767            .app_state
1768            .db
1769            .can_user_access_channel(user_id, channel_id)
1770            .await?
1771        {
1772            Err(anyhow!("access denied"))?;
1773        }
1774
1775        let messages = self
1776            .app_state
1777            .db
1778            .get_channel_messages(
1779                channel_id,
1780                MESSAGE_COUNT_PER_PAGE,
1781                Some(MessageId::from_proto(request.payload.before_message_id)),
1782            )
1783            .await?
1784            .into_iter()
1785            .map(|msg| proto::ChannelMessage {
1786                id: msg.id.to_proto(),
1787                body: msg.body,
1788                timestamp: msg.sent_at.unix_timestamp() as u64,
1789                sender_id: msg.sender_id.to_proto(),
1790                nonce: Some(msg.nonce.as_u128().into()),
1791            })
1792            .collect::<Vec<_>>();
1793        response.send(proto::GetChannelMessagesResponse {
1794            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1795            messages,
1796        })?;
1797        Ok(())
1798    }
1799
1800    async fn update_diff_base(
1801        self: Arc<Server>,
1802        request: TypedEnvelope<proto::UpdateDiffBase>,
1803    ) -> Result<()> {
1804        let receiver_ids = self.store().await.project_connection_ids(
1805            ProjectId::from_proto(request.payload.project_id),
1806            request.sender_id,
1807        )?;
1808        broadcast(request.sender_id, receiver_ids, |connection_id| {
1809            self.peer
1810                .forward_send(request.sender_id, connection_id, request.payload.clone())
1811        });
1812        Ok(())
1813    }
1814
1815    async fn get_private_user_info(
1816        self: Arc<Self>,
1817        request: TypedEnvelope<proto::GetPrivateUserInfo>,
1818        response: Response<proto::GetPrivateUserInfo>,
1819    ) -> Result<()> {
1820        let user_id = self
1821            .store()
1822            .await
1823            .user_id_for_connection(request.sender_id)?;
1824        let metrics_id = self.app_state.db.get_user_metrics_id(user_id).await?;
1825        let user = self
1826            .app_state
1827            .db
1828            .get_user_by_id(user_id)
1829            .await?
1830            .ok_or_else(|| anyhow!("user not found"))?;
1831        response.send(proto::GetPrivateUserInfoResponse {
1832            metrics_id,
1833            staff: user.admin,
1834        })?;
1835        Ok(())
1836    }
1837
    /// Locks the shared `Store` and wraps the lock guard in a `StoreGuard`,
    /// which derefs to `Store`.
    ///
    /// In test builds the task yields to the scheduler both before and after
    /// acquiring the lock, presumably to encourage interleavings between
    /// concurrent handlers in tests (NOTE(review): confirm intent).
    pub(crate) async fn store(&self) -> StoreGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreGuard {
            guard,
            // Marker field; judging by its name it keeps the guard from being
            // `Send` (its type is declared elsewhere; TODO confirm).
            _not_send: PhantomData,
        }
    }
1849
1850    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1851        ServerSnapshot {
1852            store: self.store().await,
1853            peer: &self.peer,
1854        }
1855    }
1856}
1857
1858impl<'a> Deref for StoreGuard<'a> {
1859    type Target = Store;
1860
1861    fn deref(&self) -> &Self::Target {
1862        &*self.guard
1863    }
1864}
1865
1866impl<'a> DerefMut for StoreGuard<'a> {
1867    fn deref_mut(&mut self) -> &mut Self::Target {
1868        &mut *self.guard
1869    }
1870}
1871
/// Releasing a `StoreGuard` doubles as a consistency checkpoint in test builds.
impl<'a> Drop for StoreGuard<'a> {
    fn drop(&mut self) {
        // Test builds only: call `check_invariants` (reached through `Deref` to
        // `Store`) every time a guard is dropped, i.e. at the end of each locked
        // section. The statement is compiled out of non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1878
1879impl Executor for RealExecutor {
1880    type Sleep = Sleep;
1881
1882    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1883        tokio::task::spawn(future);
1884    }
1885
1886    fn sleep(&self, duration: Duration) -> Self::Sleep {
1887        tokio::time::sleep(duration)
1888    }
1889}
1890
1891fn broadcast<F>(
1892    sender_id: ConnectionId,
1893    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1894    mut f: F,
1895) where
1896    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1897{
1898    for receiver_id in receiver_ids {
1899        if receiver_id != sender_id {
1900            f(receiver_id).trace_err();
1901        }
1902    }
1903}
1904
lazy_static! {
    /// Name of the `x-zed-protocol-version` request header, decoded by the
    /// `ProtocolVersion` typed header.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1908
/// Typed form of the `x-zed-protocol-version` header: the RPC protocol version
/// a connecting client reports, compared against `rpc::PROTOCOL_VERSION` before
/// the WebSocket upgrade.
pub struct ProtocolVersion(u32);
1910
1911impl Header for ProtocolVersion {
1912    fn name() -> &'static HeaderName {
1913        &ZED_PROTOCOL_VERSION
1914    }
1915
1916    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1917    where
1918        Self: Sized,
1919        I: Iterator<Item = &'i axum::http::HeaderValue>,
1920    {
1921        let version = values
1922            .next()
1923            .ok_or_else(axum::headers::Error::invalid)?
1924            .to_str()
1925            .map_err(|_| axum::headers::Error::invalid())?
1926            .parse()
1927            .map_err(|_| axum::headers::Error::invalid())?;
1928        Ok(Self(version))
1929    }
1930
1931    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1932        values.extend([self.0.to_string().parse().unwrap()]);
1933    }
1934}
1935
/// Builds the collab server's HTTP router.
///
/// Layer order is significant here, because axum's `Router::layer` wraps only
/// the routes registered before it:
/// - `/rpc` is registered before the first `.layer(...)`, so it receives the
///   `AppState` extension and passes through `auth::validate_header`.
/// - `/metrics` is registered after that layer, so it is not authenticated.
/// - The `Arc<Server>` extension is layered last and reaches both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Wraps only `/rpc`. Within a `ServiceBuilder` the first-added layer sees
        // the request first, so `AppState` is in the extensions by the time
        // `validate_header` runs.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1947
1948pub async fn handle_websocket_request(
1949    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1950    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1951    Extension(server): Extension<Arc<Server>>,
1952    Extension(user): Extension<User>,
1953    ws: WebSocketUpgrade,
1954) -> axum::response::Response {
1955    if protocol_version != rpc::PROTOCOL_VERSION {
1956        return (
1957            StatusCode::UPGRADE_REQUIRED,
1958            "client must be upgraded".to_string(),
1959        )
1960            .into_response();
1961    }
1962    let socket_address = socket_address.to_string();
1963    ws.on_upgrade(move |socket| {
1964        use util::ResultExt;
1965        let socket = socket
1966            .map_ok(to_tungstenite_message)
1967            .err_into()
1968            .with(|message| async move { Ok(to_axum_message(message)) });
1969        let connection = Connection::new(Box::pin(socket));
1970        async move {
1971            server
1972                .handle_connection(connection, socket_address, user, None, RealExecutor)
1973                .await
1974                .log_err();
1975        }
1976    })
1977}
1978
1979pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1980    // We call `store_mut` here for its side effects of updating metrics.
1981    let metrics = server.store().await.metrics();
1982    METRIC_CONNECTIONS.set(metrics.connections as _);
1983    METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1984    METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1985    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1986
1987    let encoder = prometheus::TextEncoder::new();
1988    let metric_families = prometheus::gather();
1989    match encoder.encode_to_string(&metric_families) {
1990        Ok(string) => (StatusCode::OK, string).into_response(),
1991        Err(error) => (
1992            StatusCode::INTERNAL_SERVER_ERROR,
1993            format!("failed to encode metrics {:?}", error),
1994        )
1995            .into_response(),
1996    }
1997}
1998
1999fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2000    match message {
2001        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2002        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2003        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2004        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2005        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2006            code: frame.code.into(),
2007            reason: frame.reason,
2008        })),
2009    }
2010}
2011
2012fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2013    match message {
2014        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2015        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2016        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2017        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2018        AxumMessage::Close(frame) => {
2019            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2020                code: frame.code.into(),
2021                reason: frame.reason,
2022            }))
2023        }
2024    }
2025}
2026
/// Turns a fallible value into an `Option`, reporting the error instead of
/// propagating it.
pub trait ResultExt {
    /// The success type carried by the fallible value.
    type Ok;

    /// Returns `Some(value)` on success; on failure reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
2032
2033impl<T, E> ResultExt for Result<T, E>
2034where
2035    E: std::fmt::Debug,
2036{
2037    type Ok = T;
2038
2039    fn trace_err(self) -> Option<T> {
2040        match self {
2041            Ok(value) => Some(value),
2042            Err(error) => {
2043                tracing::error!("{:?}", error);
2044                None
2045            }
2046        }
2047    }
2048}