rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, ProjectId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52use time::OffsetDateTime;
  53use tokio::{
  54    sync::{Mutex, MutexGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub use store::{Store, Worktree};
  61
// Process-wide Prometheus gauges, created and registered through `register_int_gauge!` the
// first time each static is dereferenced (lazy_static initialization). The `unwrap` panics if
// registration fails.
// NOTE(review): none of these gauges is set in the code visible alongside this declaration;
// confirm they are updated elsewhere, otherwise they will always report 0.
lazy_static! {
    // Gauge "connections": number of connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge "registered_projects": number of registered projects.
    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
    // Gauge "active_projects": number of active projects.
    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
        register_int_gauge!("active_projects", "number of active projects").unwrap();
    // Gauge "shared_projects": number of open projects with one or more guests.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  75
/// Type-erased handler for one incoming message type.
///
/// Receives the server and the boxed envelope and returns a boxed future that performs the
/// handling. Handlers are stored in `Server::handlers`, keyed by the payload's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  78
/// Handle given to request handlers for answering a request of type `R`.
///
/// `add_request_handler` keeps a clone of `responded` and checks it after the handler
/// finishes, reporting an error if the request was never answered.
struct Response<R> {
    /// Server whose peer is used to send the response.
    server: Arc<Server>,
    /// Receipt of the request being answered; addresses the response.
    receipt: Receipt<R>,
    /// Set to `true` once the request has been answered or its receipt has been taken.
    responded: Arc<AtomicBool>,
}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.server.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91
  92    fn into_receipt(self) -> Receipt<R> {
  93        self.responded.store(true, SeqCst);
  94        self.receipt
  95    }
  96}
  97
/// The RPC server: owns the peer used to talk to clients, the shared in-memory store, and the
/// table of message handlers that incoming messages are dispatched to.
pub struct Server {
    /// Messaging layer shared by all client connections.
    peer: Arc<Peer>,
    /// In-memory server state, behind an async (`tokio`) mutex.
    pub(crate) store: Mutex<Store>,
    /// Shared application state (database handle, invite link prefix, ...).
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the payload type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When set, a `()` is sent on this channel after each incoming message has been handled.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
 105
/// Abstraction over the async runtime the server uses to spawn background tasks and to sleep.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Starts `future` in the background without waiting for it to finish.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future that completes after `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 111
/// Runtime-backed [`Executor`] used outside of tests.
// NOTE(review): presumably backed by tokio (cf. the `tokio::time::Sleep` import); confirm
// against its `impl Executor`.
#[derive(Clone)]
pub struct RealExecutor;
 114
/// Page size for channel message queries.
// NOTE(review): presumably used by `get_channel_messages`; confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length of a channel message.
// NOTE(review): presumably enforced by `send_channel_message`; confirm whether this counts
// bytes or characters.
const MAX_MESSAGE_LEN: usize = 1024;
 117
/// Lock guard over the server's [`Store`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that keeps it alive
/// across an `.await` is therefore not `Send`, which prevents holding the store lock across
/// suspension points in tasks that must be `Send` (e.g. those passed to
/// [`Executor::spawn_detached`]).
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 122
/// Serializable view of the server: its peer plus the locked store.
///
/// The store field is serialized through `serialize_deref`, i.e. as the value the guard
/// dereferences to; the lock is held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 129
 130pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 131where
 132    S: Serializer,
 133    T: Deref<Target = U>,
 134    U: Serialize,
 135{
 136    Serialize::serialize(value.deref(), serializer)
 137}
 138
 139impl Server {
 140    pub fn new(
 141        app_state: Arc<AppState>,
 142        notifications: Option<mpsc::UnboundedSender<()>>,
 143    ) -> Arc<Self> {
 144        let mut server = Self {
 145            peer: Peer::new(),
 146            app_state,
 147            store: Default::default(),
 148            handlers: Default::default(),
 149            notifications,
 150        };
 151
 152        server
 153            .add_request_handler(Server::ping)
 154            .add_request_handler(Server::register_project)
 155            .add_request_handler(Server::unregister_project)
 156            .add_request_handler(Server::join_project)
 157            .add_message_handler(Server::leave_project)
 158            .add_message_handler(Server::respond_to_join_project_request)
 159            .add_message_handler(Server::update_project)
 160            .add_message_handler(Server::register_project_activity)
 161            .add_request_handler(Server::update_worktree)
 162            .add_message_handler(Server::update_worktree_extensions)
 163            .add_message_handler(Server::start_language_server)
 164            .add_message_handler(Server::update_language_server)
 165            .add_message_handler(Server::update_diagnostic_summary)
 166            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 167            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 168            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
 169            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 170            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 172            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 173            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 174            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 175            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 176            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 177            .add_request_handler(
 178                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 179            )
 180            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 181            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 182            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 183            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 184            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 185            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 186            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 187            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 188            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 189            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 190            .add_request_handler(Server::update_buffer)
 191            .add_message_handler(Server::update_buffer_file)
 192            .add_message_handler(Server::buffer_reloaded)
 193            .add_message_handler(Server::buffer_saved)
 194            .add_request_handler(Server::save_buffer)
 195            .add_request_handler(Server::get_channels)
 196            .add_request_handler(Server::get_users)
 197            .add_request_handler(Server::fuzzy_search_users)
 198            .add_request_handler(Server::request_contact)
 199            .add_request_handler(Server::remove_contact)
 200            .add_request_handler(Server::respond_to_contact_request)
 201            .add_request_handler(Server::join_channel)
 202            .add_message_handler(Server::leave_channel)
 203            .add_request_handler(Server::send_channel_message)
 204            .add_request_handler(Server::follow)
 205            .add_message_handler(Server::unfollow)
 206            .add_message_handler(Server::update_followers)
 207            .add_request_handler(Server::get_channel_messages);
 208
 209        Arc::new(server)
 210    }
 211
 212    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 213    where
 214        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 215        Fut: 'static + Send + Future<Output = Result<()>>,
 216        M: EnvelopedMessage,
 217    {
 218        let prev_handler = self.handlers.insert(
 219            TypeId::of::<M>(),
 220            Box::new(move |server, envelope| {
 221                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 222                let span = info_span!(
 223                    "handle message",
 224                    payload_type = envelope.payload_type_name()
 225                );
 226                span.in_scope(|| {
 227                    tracing::info!(
 228                        payload_type = envelope.payload_type_name(),
 229                        "message received"
 230                    );
 231                });
 232                let future = (handler)(server, *envelope);
 233                async move {
 234                    if let Err(error) = future.await {
 235                        tracing::error!(%error, "error handling message");
 236                    }
 237                }
 238                .instrument(span)
 239                .boxed()
 240            }),
 241        );
 242        if prev_handler.is_some() {
 243            panic!("registered a handler for the same message twice");
 244        }
 245        self
 246    }
 247
 248    /// Handle a request while holding a lock to the store. This is useful when we're registering
 249    /// a connection but we want to respond on the connection before anybody else can send on it.
 250    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 251    where
 252        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 253        Fut: Send + Future<Output = Result<()>>,
 254        M: RequestMessage,
 255    {
 256        let handler = Arc::new(handler);
 257        self.add_message_handler(move |server, envelope| {
 258            let receipt = envelope.receipt();
 259            let handler = handler.clone();
 260            async move {
 261                let responded = Arc::new(AtomicBool::default());
 262                let response = Response {
 263                    server: server.clone(),
 264                    responded: responded.clone(),
 265                    receipt: envelope.receipt(),
 266                };
 267                match (handler)(server.clone(), envelope, response).await {
 268                    Ok(()) => {
 269                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 270                            Ok(())
 271                        } else {
 272                            Err(anyhow!("handler did not send a response"))?
 273                        }
 274                    }
 275                    Err(error) => {
 276                        server.peer.respond_with_error(
 277                            receipt,
 278                            proto::Error {
 279                                message: error.to_string(),
 280                            },
 281                        )?;
 282                        Err(error)
 283                    }
 284                }
 285            }
 286        })
 287    }
 288
 289    /// Start a long lived task that records which users are active in which projects.
 290    pub fn start_recording_project_activity<E: 'static + Executor>(
 291        self: &Arc<Self>,
 292        interval: Duration,
 293        executor: E,
 294    ) {
 295        executor.spawn_detached({
 296            let this = Arc::downgrade(self);
 297            let executor = executor.clone();
 298            async move {
 299                let mut period_start = OffsetDateTime::now_utc();
 300                let mut active_projects = Vec::<(UserId, ProjectId)>::new();
 301                loop {
 302                    let sleep = executor.sleep(interval);
 303                    sleep.await;
 304                    let this = if let Some(this) = this.upgrade() {
 305                        this
 306                    } else {
 307                        break;
 308                    };
 309
 310                    active_projects.clear();
 311                    active_projects.extend(this.store().await.projects().flat_map(
 312                        |(project_id, project)| {
 313                            project.guests.values().chain([&project.host]).filter_map(
 314                                |collaborator| {
 315                                    if !collaborator.admin
 316                                        && collaborator
 317                                            .last_activity
 318                                            .map_or(false, |activity| activity > period_start)
 319                                    {
 320                                        Some((collaborator.user_id, *project_id))
 321                                    } else {
 322                                        None
 323                                    }
 324                                },
 325                            )
 326                        },
 327                    ));
 328
 329                    let period_end = OffsetDateTime::now_utc();
 330                    this.app_state
 331                        .db
 332                        .record_user_activity(period_start..period_end, &active_projects)
 333                        .await
 334                        .trace_err();
 335                    period_start = period_end;
 336                }
 337            }
 338        });
 339    }
 340
    /// Serves one authenticated client connection until it closes.
    ///
    /// In order, the returned future:
    /// 1. registers `connection` with the peer, using `executor` to create timers;
    /// 2. reports the new `ConnectionId` through `send_connection_id`, if one was provided;
    /// 3. on the user's first-ever connection, sends `ShowContacts` and records in the
    ///    database that the user has connected once;
    /// 4. adds the connection to the store, sends the initial contacts update (plus invite
    ///    info when the user has an invite code), and refreshes the user's status for their
    ///    contacts;
    /// 5. dispatches incoming messages to the registered handlers until the peer's I/O future
    ///    completes or the incoming message stream ends;
    /// 6. signs the connection out. Sign-out errors are logged, not returned.
    ///
    /// NOTE(review): an error from any `?` in steps 3-4 returns early without reaching
    /// `sign_out`, so the connection may remain registered with the peer (and, once
    /// `add_connection` has run, with the store). Confirm whether callers compensate for this.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                })
                .await;

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");

            if let Some(send_connection_id) = send_connection_id.as_mut() {
                // Best effort: the result of the send is intentionally ignored.
                let _ = send_connection_id.send(connection_id).await;
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch contacts and the invite code concurrently.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            {
                // Scoped so the store lock is released before `update_user_contacts`, which
                // locks the store again.
                let mut store = this.store().await;
                store.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in order: I/O completion first, then
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let notifications = this.notifications.clone();
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), message);

                                // The entered guard covers only the synchronous construction of
                                // the handler future; the future itself is instrumented with
                                // the same span below.
                                drop(span_enter);
                                let handle_message = async move {
                                    handle_message.await;
                                    // Signal the optional notification channel once handled.
                                    if let Some(mut notifications) = notifications {
                                        let _ = notifications.send(()).await;
                                    }
                                }.instrument(span);

                                // Background messages run on the executor; the rest are polled
                                // on this task via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 463
 464    #[instrument(skip(self), err)]
 465    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 466        self.peer.disconnect(connection_id);
 467
 468        let mut projects_to_unregister = Vec::new();
 469        let removed_user_id;
 470        {
 471            let mut store = self.store().await;
 472            let removed_connection = store.remove_connection(connection_id)?;
 473
 474            for (project_id, project) in removed_connection.hosted_projects {
 475                projects_to_unregister.push(project_id);
 476                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 477                    self.peer.send(
 478                        conn_id,
 479                        proto::UnregisterProject {
 480                            project_id: project_id.to_proto(),
 481                        },
 482                    )
 483                });
 484
 485                for (_, receipts) in project.join_requests {
 486                    for receipt in receipts {
 487                        self.peer.respond(
 488                            receipt,
 489                            proto::JoinProjectResponse {
 490                                variant: Some(proto::join_project_response::Variant::Decline(
 491                                    proto::join_project_response::Decline {
 492                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 493                                    },
 494                                )),
 495                            },
 496                        )?;
 497                    }
 498                }
 499            }
 500
 501            for project_id in removed_connection.guest_project_ids {
 502                if let Some(project) = store.project(project_id).trace_err() {
 503                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 504                        self.peer.send(
 505                            conn_id,
 506                            proto::RemoveProjectCollaborator {
 507                                project_id: project_id.to_proto(),
 508                                peer_id: connection_id.0,
 509                            },
 510                        )
 511                    });
 512                    if project.guests.is_empty() {
 513                        self.peer
 514                            .send(
 515                                project.host_connection_id,
 516                                proto::ProjectUnshared {
 517                                    project_id: project_id.to_proto(),
 518                                },
 519                            )
 520                            .trace_err();
 521                    }
 522                }
 523            }
 524
 525            removed_user_id = removed_connection.user_id;
 526        };
 527
 528        self.update_user_contacts(removed_user_id).await.trace_err();
 529
 530        for project_id in projects_to_unregister {
 531            self.app_state
 532                .db
 533                .unregister_project(project_id)
 534                .await
 535                .trace_err();
 536        }
 537
 538        Ok(())
 539    }
 540
 541    pub async fn invite_code_redeemed(
 542        self: &Arc<Self>,
 543        code: &str,
 544        invitee_id: UserId,
 545    ) -> Result<()> {
 546        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 547        let store = self.store().await;
 548        let invitee_contact = store.contact_for_user(invitee_id, true);
 549        for connection_id in store.connection_ids_for_user(user.id) {
 550            self.peer.send(
 551                connection_id,
 552                proto::UpdateContacts {
 553                    contacts: vec![invitee_contact.clone()],
 554                    ..Default::default()
 555                },
 556            )?;
 557            self.peer.send(
 558                connection_id,
 559                proto::UpdateInviteInfo {
 560                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 561                    count: user.invite_count as u32,
 562                },
 563            )?;
 564        }
 565        Ok(())
 566    }
 567
 568    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 569        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 570            if let Some(invite_code) = &user.invite_code {
 571                let store = self.store().await;
 572                for connection_id in store.connection_ids_for_user(user_id) {
 573                    self.peer.send(
 574                        connection_id,
 575                        proto::UpdateInviteInfo {
 576                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 577                            count: user.invite_count as u32,
 578                        },
 579                    )?;
 580                }
 581            }
 582        }
 583        Ok(())
 584    }
 585
 586    async fn ping(
 587        self: Arc<Server>,
 588        _: TypedEnvelope<proto::Ping>,
 589        response: Response<proto::Ping>,
 590    ) -> Result<()> {
 591        response.send(proto::Ack {})?;
 592        Ok(())
 593    }
 594
 595    async fn register_project(
 596        self: Arc<Server>,
 597        request: TypedEnvelope<proto::RegisterProject>,
 598        response: Response<proto::RegisterProject>,
 599    ) -> Result<()> {
 600        let user_id = self
 601            .store()
 602            .await
 603            .user_id_for_connection(request.sender_id)?;
 604        let project_id = self.app_state.db.register_project(user_id).await?;
 605        self.store().await.register_project(
 606            request.sender_id,
 607            project_id,
 608            request.payload.online,
 609        )?;
 610
 611        response.send(proto::RegisterProjectResponse {
 612            project_id: project_id.to_proto(),
 613        })?;
 614
 615        Ok(())
 616    }
 617
 618    async fn unregister_project(
 619        self: Arc<Server>,
 620        request: TypedEnvelope<proto::UnregisterProject>,
 621        response: Response<proto::UnregisterProject>,
 622    ) -> Result<()> {
 623        let project_id = ProjectId::from_proto(request.payload.project_id);
 624        let (user_id, project) = {
 625            let mut state = self.store().await;
 626            let project = state.unregister_project(project_id, request.sender_id)?;
 627            (state.user_id_for_connection(request.sender_id)?, project)
 628        };
 629        self.app_state.db.unregister_project(project_id).await?;
 630
 631        broadcast(
 632            request.sender_id,
 633            project.guests.keys().copied(),
 634            |conn_id| {
 635                self.peer.send(
 636                    conn_id,
 637                    proto::UnregisterProject {
 638                        project_id: project_id.to_proto(),
 639                    },
 640                )
 641            },
 642        );
 643        for (_, receipts) in project.join_requests {
 644            for receipt in receipts {
 645                self.peer.respond(
 646                    receipt,
 647                    proto::JoinProjectResponse {
 648                        variant: Some(proto::join_project_response::Variant::Decline(
 649                            proto::join_project_response::Decline {
 650                                reason: proto::join_project_response::decline::Reason::Closed
 651                                    as i32,
 652                            },
 653                        )),
 654                    },
 655                )?;
 656            }
 657        }
 658
 659        // Send out the `UpdateContacts` message before responding to the unregister
 660        // request. This way, when the project's host can keep track of the project's
 661        // remote id until after they've received the `UpdateContacts` message for
 662        // themself.
 663        self.update_user_contacts(user_id).await?;
 664        response.send(proto::Ack {})?;
 665
 666        Ok(())
 667    }
 668
 669    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 670        let contacts = self.app_state.db.get_contacts(user_id).await?;
 671        let store = self.store().await;
 672        let updated_contact = store.contact_for_user(user_id, false);
 673        for contact in contacts {
 674            if let db::Contact::Accepted {
 675                user_id: contact_user_id,
 676                ..
 677            } = contact
 678            {
 679                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 680                    self.peer
 681                        .send(
 682                            contact_conn_id,
 683                            proto::UpdateContacts {
 684                                contacts: vec![updated_contact.clone()],
 685                                remove_contacts: Default::default(),
 686                                incoming_requests: Default::default(),
 687                                remove_incoming_requests: Default::default(),
 688                                outgoing_requests: Default::default(),
 689                                remove_outgoing_requests: Default::default(),
 690                            },
 691                        )
 692                        .trace_err();
 693                }
 694            }
 695        }
 696        Ok(())
 697    }
 698
 699    async fn join_project(
 700        self: Arc<Server>,
 701        request: TypedEnvelope<proto::JoinProject>,
 702        response: Response<proto::JoinProject>,
 703    ) -> Result<()> {
 704        let project_id = ProjectId::from_proto(request.payload.project_id);
 705
 706        let host_user_id;
 707        let guest_user_id;
 708        let host_connection_id;
 709        {
 710            let state = self.store().await;
 711            let project = state.project(project_id)?;
 712            host_user_id = project.host.user_id;
 713            host_connection_id = project.host_connection_id;
 714            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 715        };
 716
 717        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 718        let has_contact = self
 719            .app_state
 720            .db
 721            .has_contact(guest_user_id, host_user_id)
 722            .await?;
 723        if !has_contact {
 724            return Err(anyhow!("no such project"))?;
 725        }
 726
 727        self.store().await.request_join_project(
 728            guest_user_id,
 729            project_id,
 730            response.into_receipt(),
 731        )?;
 732        self.peer.send(
 733            host_connection_id,
 734            proto::RequestJoinProject {
 735                project_id: project_id.to_proto(),
 736                requester_id: guest_user_id.to_proto(),
 737            },
 738        )?;
 739        Ok(())
 740    }
 741
    /// Handles the host's answer to a pending request to join one of its projects.
    ///
    /// Only the project's host connection may answer; any other sender gets a
    /// "no such connection" error. When `allow` is false, every receipt returned by
    /// the store for this request is answered with a `Declined` response. When it
    /// is true, each accepted receipt gets an `Accept` response (worktree metadata,
    /// its replica id, the existing collaborators and the language servers), the
    /// project's connections are told about the new collaborator(s), and every
    /// worktree's entries and diagnostic summaries are streamed to the accepted
    /// connections. On the accept path, the host's contacts are refreshed once the
    /// store lock has been released.
    async fn respond_to_join_project_request(
        self: Arc<Server>,
        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
    ) -> Result<()> {
        // Assigned inside the locked scope below and used after the lock is dropped.
        let host_user_id;

        {
            let mut state = self.store().await;
            let project_id = ProjectId::from_proto(request.payload.project_id);
            let project = state.project(project_id)?;
            // Only the host may accept or decline join requests.
            if project.host_connection_id != request.sender_id {
                Err(anyhow!("no such connection"))?;
            }

            host_user_id = project.host.user_id;
            let guest_user_id = UserId::from_proto(request.payload.requester_id);

            if !request.payload.allow {
                // Answer each receipt the store hands back with a `Declined` response.
                let receipts = state
                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
                    .ok_or_else(|| anyhow!("no such request"))?;
                for receipt in receipts {
                    self.peer.respond(
                        receipt,
                        proto::JoinProjectResponse {
                            variant: Some(proto::join_project_response::Variant::Decline(
                                proto::join_project_response::Decline {
                                    reason: proto::join_project_response::decline::Reason::Declined
                                        as i32,
                                },
                            )),
                        },
                    )?;
                }
                // The decline path returns early and does not call `update_user_contacts`.
                return Ok(());
            }

            // Each accepted receipt is paired with the replica id used for that connection.
            let (receipts_with_replica_ids, project) = state
                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
                .ok_or_else(|| anyhow!("no such request"))?;

            let peer_count = project.guests.len();
            let mut collaborators = Vec::with_capacity(peer_count);
            // The host is always listed first, as replica 0.
            collaborators.push(proto::Collaborator {
                peer_id: project.host_connection_id.0,
                replica_id: 0,
                user_id: project.host.user_id.to_proto(),
            });
            let worktrees = project
                .worktrees
                .iter()
                .map(|(id, worktree)| proto::WorktreeMetadata {
                    id: *id,
                    root_name: worktree.root_name.clone(),
                    visible: worktree.visible,
                })
                .collect::<Vec<_>>();

            // Add all guests other than the requesting user's own connections as collaborators
            for (guest_conn_id, guest) in &project.guests {
                if receipts_with_replica_ids
                    .iter()
                    .all(|(receipt, _)| receipt.sender_id != *guest_conn_id)
                {
                    collaborators.push(proto::Collaborator {
                        peer_id: guest_conn_id.0,
                        replica_id: guest.replica_id as u32,
                        user_id: guest.user_id.to_proto(),
                    });
                }
            }

            // Tell every project connection about each newly accepted connection,
            // skipping only the accepted connection itself.
            for conn_id in project.connection_ids() {
                for (receipt, replica_id) in &receipts_with_replica_ids {
                    if conn_id != receipt.sender_id {
                        self.peer.send(
                            conn_id,
                            proto::AddProjectCollaborator {
                                project_id: project_id.to_proto(),
                                collaborator: Some(proto::Collaborator {
                                    peer_id: receipt.sender_id.0,
                                    replica_id: *replica_id as u32,
                                    user_id: guest_user_id.to_proto(),
                                }),
                            },
                        )?;
                    }
                }
            }

            // First, we send the metadata associated with each worktree.
            for (receipt, replica_id) in &receipts_with_replica_ids {
                self.peer.respond(
                    receipt.clone(),
                    proto::JoinProjectResponse {
                        variant: Some(proto::join_project_response::Variant::Accept(
                            proto::join_project_response::Accept {
                                worktrees: worktrees.clone(),
                                replica_id: *replica_id as u32,
                                collaborators: collaborators.clone(),
                                language_servers: project.language_servers.clone(),
                            },
                        )),
                    },
                )?;
            }

            // Then stream each worktree's contents to the accepted connections.
            for (worktree_id, worktree) in &project.worktrees {
                // Test builds use a much smaller chunk size (presumably so that tests
                // exercise the multi-chunk path).
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project_id.to_proto(),
                    worktree_id: *worktree_id,
                    root_name: worktree.root_name.clone(),
                    updated_entries: worktree.entries.values().cloned().collect(),
                    removed_entries: Default::default(),
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.is_complete,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    for (receipt, _) in &receipts_with_replica_ids {
                        self.peer.send(receipt.sender_id, update.clone())?;
                    }
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries.values() {
                    for (receipt, _) in &receipts_with_replica_ids {
                        self.peer.send(
                            receipt.sender_id,
                            proto::UpdateDiagnosticSummary {
                                project_id: project_id.to_proto(),
                                worktree_id: *worktree_id,
                                summary: Some(summary.clone()),
                            },
                        )?;
                    }
                }
            }
        }

        // The store lock is released at this point.
        self.update_user_contacts(host_user_id).await?;
        Ok(())
    }
 890
 891    async fn leave_project(
 892        self: Arc<Server>,
 893        request: TypedEnvelope<proto::LeaveProject>,
 894    ) -> Result<()> {
 895        let sender_id = request.sender_id;
 896        let project_id = ProjectId::from_proto(request.payload.project_id);
 897        let project;
 898        {
 899            let mut store = self.store().await;
 900            project = store.leave_project(sender_id, project_id)?;
 901            tracing::info!(
 902                %project_id,
 903                host_user_id = %project.host_user_id,
 904                host_connection_id = %project.host_connection_id,
 905                "leave project"
 906            );
 907
 908            if project.remove_collaborator {
 909                broadcast(sender_id, project.connection_ids, |conn_id| {
 910                    self.peer.send(
 911                        conn_id,
 912                        proto::RemoveProjectCollaborator {
 913                            project_id: project_id.to_proto(),
 914                            peer_id: sender_id.0,
 915                        },
 916                    )
 917                });
 918            }
 919
 920            if let Some(requester_id) = project.cancel_request {
 921                self.peer.send(
 922                    project.host_connection_id,
 923                    proto::JoinProjectRequestCancelled {
 924                        project_id: project_id.to_proto(),
 925                        requester_id: requester_id.to_proto(),
 926                    },
 927                )?;
 928            }
 929
 930            if project.unshare {
 931                self.peer.send(
 932                    project.host_connection_id,
 933                    proto::ProjectUnshared {
 934                        project_id: project_id.to_proto(),
 935                    },
 936                )?;
 937            }
 938        }
 939        self.update_user_contacts(project.host_user_id).await?;
 940        Ok(())
 941    }
 942
 943    async fn update_project(
 944        self: Arc<Server>,
 945        request: TypedEnvelope<proto::UpdateProject>,
 946    ) -> Result<()> {
 947        let project_id = ProjectId::from_proto(request.payload.project_id);
 948        let user_id;
 949        {
 950            let mut state = self.store().await;
 951            user_id = state.user_id_for_connection(request.sender_id)?;
 952            let guest_connection_ids = state
 953                .read_project(project_id, request.sender_id)?
 954                .guest_connection_ids();
 955            let unshared_project = state.update_project(
 956                project_id,
 957                &request.payload.worktrees,
 958                request.payload.online,
 959                request.sender_id,
 960            )?;
 961
 962            if let Some(unshared_project) = unshared_project {
 963                broadcast(
 964                    request.sender_id,
 965                    unshared_project.guests.keys().copied(),
 966                    |conn_id| {
 967                        self.peer.send(
 968                            conn_id,
 969                            proto::UnregisterProject {
 970                                project_id: project_id.to_proto(),
 971                            },
 972                        )
 973                    },
 974                );
 975                for (_, receipts) in unshared_project.pending_join_requests {
 976                    for receipt in receipts {
 977                        self.peer.respond(
 978                            receipt,
 979                            proto::JoinProjectResponse {
 980                                variant: Some(proto::join_project_response::Variant::Decline(
 981                                    proto::join_project_response::Decline {
 982                                        reason:
 983                                            proto::join_project_response::decline::Reason::Closed
 984                                                as i32,
 985                                    },
 986                                )),
 987                            },
 988                        )?;
 989                    }
 990                }
 991            } else {
 992                broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 993                    self.peer.forward_send(
 994                        request.sender_id,
 995                        connection_id,
 996                        request.payload.clone(),
 997                    )
 998                });
 999            }
1000        };
1001
1002        self.update_user_contacts(user_id).await?;
1003        Ok(())
1004    }
1005
1006    async fn register_project_activity(
1007        self: Arc<Server>,
1008        request: TypedEnvelope<proto::RegisterProjectActivity>,
1009    ) -> Result<()> {
1010        self.store().await.register_project_activity(
1011            ProjectId::from_proto(request.payload.project_id),
1012            request.sender_id,
1013        )?;
1014        Ok(())
1015    }
1016
1017    async fn update_worktree(
1018        self: Arc<Server>,
1019        request: TypedEnvelope<proto::UpdateWorktree>,
1020        response: Response<proto::UpdateWorktree>,
1021    ) -> Result<()> {
1022        let project_id = ProjectId::from_proto(request.payload.project_id);
1023        let worktree_id = request.payload.worktree_id;
1024        let (connection_ids, metadata_changed) = {
1025            let mut store = self.store().await;
1026            let (connection_ids, metadata_changed) = store.update_worktree(
1027                request.sender_id,
1028                project_id,
1029                worktree_id,
1030                &request.payload.root_name,
1031                &request.payload.removed_entries,
1032                &request.payload.updated_entries,
1033                request.payload.scan_id,
1034                request.payload.is_last_update,
1035            )?;
1036            (connection_ids, metadata_changed)
1037        };
1038
1039        broadcast(request.sender_id, connection_ids, |connection_id| {
1040            self.peer
1041                .forward_send(request.sender_id, connection_id, request.payload.clone())
1042        });
1043        if metadata_changed {
1044            let user_id = self
1045                .store()
1046                .await
1047                .user_id_for_connection(request.sender_id)?;
1048            self.update_user_contacts(user_id).await?;
1049        }
1050        response.send(proto::Ack {})?;
1051        Ok(())
1052    }
1053
1054    async fn update_worktree_extensions(
1055        self: Arc<Server>,
1056        request: TypedEnvelope<proto::UpdateWorktreeExtensions>,
1057    ) -> Result<()> {
1058        let project_id = ProjectId::from_proto(request.payload.project_id);
1059        let worktree_id = request.payload.worktree_id;
1060        let extensions = request
1061            .payload
1062            .extensions
1063            .into_iter()
1064            .zip(request.payload.counts)
1065            .collect();
1066        self.app_state
1067            .db
1068            .update_worktree_extensions(project_id, worktree_id, extensions)
1069            .await?;
1070        Ok(())
1071    }
1072
1073    async fn update_diagnostic_summary(
1074        self: Arc<Server>,
1075        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
1076    ) -> Result<()> {
1077        let summary = request
1078            .payload
1079            .summary
1080            .clone()
1081            .ok_or_else(|| anyhow!("invalid summary"))?;
1082        let receiver_ids = self.store().await.update_diagnostic_summary(
1083            ProjectId::from_proto(request.payload.project_id),
1084            request.payload.worktree_id,
1085            request.sender_id,
1086            summary,
1087        )?;
1088
1089        broadcast(request.sender_id, receiver_ids, |connection_id| {
1090            self.peer
1091                .forward_send(request.sender_id, connection_id, request.payload.clone())
1092        });
1093        Ok(())
1094    }
1095
1096    async fn start_language_server(
1097        self: Arc<Server>,
1098        request: TypedEnvelope<proto::StartLanguageServer>,
1099    ) -> Result<()> {
1100        let receiver_ids = self.store().await.start_language_server(
1101            ProjectId::from_proto(request.payload.project_id),
1102            request.sender_id,
1103            request
1104                .payload
1105                .server
1106                .clone()
1107                .ok_or_else(|| anyhow!("invalid language server"))?,
1108        )?;
1109        broadcast(request.sender_id, receiver_ids, |connection_id| {
1110            self.peer
1111                .forward_send(request.sender_id, connection_id, request.payload.clone())
1112        });
1113        Ok(())
1114    }
1115
1116    async fn update_language_server(
1117        self: Arc<Server>,
1118        request: TypedEnvelope<proto::UpdateLanguageServer>,
1119    ) -> Result<()> {
1120        let receiver_ids = self.store().await.project_connection_ids(
1121            ProjectId::from_proto(request.payload.project_id),
1122            request.sender_id,
1123        )?;
1124        broadcast(request.sender_id, receiver_ids, |connection_id| {
1125            self.peer
1126                .forward_send(request.sender_id, connection_id, request.payload.clone())
1127        });
1128        Ok(())
1129    }
1130
1131    async fn forward_project_request<T>(
1132        self: Arc<Server>,
1133        request: TypedEnvelope<T>,
1134        response: Response<T>,
1135    ) -> Result<()>
1136    where
1137        T: EntityMessage + RequestMessage,
1138    {
1139        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1140        let host_connection_id = self
1141            .store()
1142            .await
1143            .read_project(project_id, request.sender_id)?
1144            .host_connection_id;
1145        let payload = self
1146            .peer
1147            .forward_request(request.sender_id, host_connection_id, request.payload)
1148            .await?;
1149
1150        // Ensure project still exists by the time we get the response from the host.
1151        self.store()
1152            .await
1153            .read_project(project_id, request.sender_id)?;
1154
1155        response.send(payload)?;
1156        Ok(())
1157    }
1158
1159    async fn save_buffer(
1160        self: Arc<Server>,
1161        request: TypedEnvelope<proto::SaveBuffer>,
1162        response: Response<proto::SaveBuffer>,
1163    ) -> Result<()> {
1164        let project_id = ProjectId::from_proto(request.payload.project_id);
1165        let host = self
1166            .store()
1167            .await
1168            .read_project(project_id, request.sender_id)?
1169            .host_connection_id;
1170        let response_payload = self
1171            .peer
1172            .forward_request(request.sender_id, host, request.payload.clone())
1173            .await?;
1174
1175        let mut guests = self
1176            .store()
1177            .await
1178            .read_project(project_id, request.sender_id)?
1179            .connection_ids();
1180        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1181        broadcast(host, guests, |conn_id| {
1182            self.peer
1183                .forward_send(host, conn_id, response_payload.clone())
1184        });
1185        response.send(response_payload)?;
1186        Ok(())
1187    }
1188
1189    async fn update_buffer(
1190        self: Arc<Server>,
1191        request: TypedEnvelope<proto::UpdateBuffer>,
1192        response: Response<proto::UpdateBuffer>,
1193    ) -> Result<()> {
1194        let project_id = ProjectId::from_proto(request.payload.project_id);
1195        let receiver_ids = {
1196            let mut store = self.store().await;
1197            store.register_project_activity(project_id, request.sender_id)?;
1198            store.project_connection_ids(project_id, request.sender_id)?
1199        };
1200
1201        broadcast(request.sender_id, receiver_ids, |connection_id| {
1202            self.peer
1203                .forward_send(request.sender_id, connection_id, request.payload.clone())
1204        });
1205        response.send(proto::Ack {})?;
1206        Ok(())
1207    }
1208
1209    async fn update_buffer_file(
1210        self: Arc<Server>,
1211        request: TypedEnvelope<proto::UpdateBufferFile>,
1212    ) -> Result<()> {
1213        let receiver_ids = self.store().await.project_connection_ids(
1214            ProjectId::from_proto(request.payload.project_id),
1215            request.sender_id,
1216        )?;
1217        broadcast(request.sender_id, receiver_ids, |connection_id| {
1218            self.peer
1219                .forward_send(request.sender_id, connection_id, request.payload.clone())
1220        });
1221        Ok(())
1222    }
1223
1224    async fn buffer_reloaded(
1225        self: Arc<Server>,
1226        request: TypedEnvelope<proto::BufferReloaded>,
1227    ) -> Result<()> {
1228        let receiver_ids = self.store().await.project_connection_ids(
1229            ProjectId::from_proto(request.payload.project_id),
1230            request.sender_id,
1231        )?;
1232        broadcast(request.sender_id, receiver_ids, |connection_id| {
1233            self.peer
1234                .forward_send(request.sender_id, connection_id, request.payload.clone())
1235        });
1236        Ok(())
1237    }
1238
1239    async fn buffer_saved(
1240        self: Arc<Server>,
1241        request: TypedEnvelope<proto::BufferSaved>,
1242    ) -> Result<()> {
1243        let receiver_ids = self.store().await.project_connection_ids(
1244            ProjectId::from_proto(request.payload.project_id),
1245            request.sender_id,
1246        )?;
1247        broadcast(request.sender_id, receiver_ids, |connection_id| {
1248            self.peer
1249                .forward_send(request.sender_id, connection_id, request.payload.clone())
1250        });
1251        Ok(())
1252    }
1253
1254    async fn follow(
1255        self: Arc<Self>,
1256        request: TypedEnvelope<proto::Follow>,
1257        response: Response<proto::Follow>,
1258    ) -> Result<()> {
1259        let project_id = ProjectId::from_proto(request.payload.project_id);
1260        let leader_id = ConnectionId(request.payload.leader_id);
1261        let follower_id = request.sender_id;
1262        {
1263            let mut store = self.store().await;
1264            if !store
1265                .project_connection_ids(project_id, follower_id)?
1266                .contains(&leader_id)
1267            {
1268                Err(anyhow!("no such peer"))?;
1269            }
1270
1271            store.register_project_activity(project_id, follower_id)?;
1272        }
1273
1274        let mut response_payload = self
1275            .peer
1276            .forward_request(request.sender_id, leader_id, request.payload)
1277            .await?;
1278        response_payload
1279            .views
1280            .retain(|view| view.leader_id != Some(follower_id.0));
1281        response.send(response_payload)?;
1282        Ok(())
1283    }
1284
1285    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1286        let project_id = ProjectId::from_proto(request.payload.project_id);
1287        let leader_id = ConnectionId(request.payload.leader_id);
1288        let mut store = self.store().await;
1289        if !store
1290            .project_connection_ids(project_id, request.sender_id)?
1291            .contains(&leader_id)
1292        {
1293            Err(anyhow!("no such peer"))?;
1294        }
1295        store.register_project_activity(project_id, request.sender_id)?;
1296        self.peer
1297            .forward_send(request.sender_id, leader_id, request.payload)?;
1298        Ok(())
1299    }
1300
1301    async fn update_followers(
1302        self: Arc<Self>,
1303        request: TypedEnvelope<proto::UpdateFollowers>,
1304    ) -> Result<()> {
1305        let project_id = ProjectId::from_proto(request.payload.project_id);
1306        let mut store = self.store().await;
1307        store.register_project_activity(project_id, request.sender_id)?;
1308        let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
1309        let leader_id = request
1310            .payload
1311            .variant
1312            .as_ref()
1313            .and_then(|variant| match variant {
1314                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1315                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1316                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1317            });
1318        for follower_id in &request.payload.follower_ids {
1319            let follower_id = ConnectionId(*follower_id);
1320            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1321                self.peer
1322                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1323            }
1324        }
1325        Ok(())
1326    }
1327
1328    async fn get_channels(
1329        self: Arc<Server>,
1330        request: TypedEnvelope<proto::GetChannels>,
1331        response: Response<proto::GetChannels>,
1332    ) -> Result<()> {
1333        let user_id = self
1334            .store()
1335            .await
1336            .user_id_for_connection(request.sender_id)?;
1337        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1338        response.send(proto::GetChannelsResponse {
1339            channels: channels
1340                .into_iter()
1341                .map(|chan| proto::Channel {
1342                    id: chan.id.to_proto(),
1343                    name: chan.name,
1344                })
1345                .collect(),
1346        })?;
1347        Ok(())
1348    }
1349
1350    async fn get_users(
1351        self: Arc<Server>,
1352        request: TypedEnvelope<proto::GetUsers>,
1353        response: Response<proto::GetUsers>,
1354    ) -> Result<()> {
1355        let user_ids = request
1356            .payload
1357            .user_ids
1358            .into_iter()
1359            .map(UserId::from_proto)
1360            .collect();
1361        let users = self
1362            .app_state
1363            .db
1364            .get_users_by_ids(user_ids)
1365            .await?
1366            .into_iter()
1367            .map(|user| proto::User {
1368                id: user.id.to_proto(),
1369                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1370                github_login: user.github_login,
1371            })
1372            .collect();
1373        response.send(proto::UsersResponse { users })?;
1374        Ok(())
1375    }
1376
1377    async fn fuzzy_search_users(
1378        self: Arc<Server>,
1379        request: TypedEnvelope<proto::FuzzySearchUsers>,
1380        response: Response<proto::FuzzySearchUsers>,
1381    ) -> Result<()> {
1382        let user_id = self
1383            .store()
1384            .await
1385            .user_id_for_connection(request.sender_id)?;
1386        let query = request.payload.query;
1387        let db = &self.app_state.db;
1388        let users = match query.len() {
1389            0 => vec![],
1390            1 | 2 => db
1391                .get_user_by_github_login(&query)
1392                .await?
1393                .into_iter()
1394                .collect(),
1395            _ => db.fuzzy_search_users(&query, 10).await?,
1396        };
1397        let users = users
1398            .into_iter()
1399            .filter(|user| user.id != user_id)
1400            .map(|user| proto::User {
1401                id: user.id.to_proto(),
1402                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1403                github_login: user.github_login,
1404            })
1405            .collect();
1406        response.send(proto::UsersResponse { users })?;
1407        Ok(())
1408    }
1409
1410    async fn request_contact(
1411        self: Arc<Server>,
1412        request: TypedEnvelope<proto::RequestContact>,
1413        response: Response<proto::RequestContact>,
1414    ) -> Result<()> {
1415        let requester_id = self
1416            .store()
1417            .await
1418            .user_id_for_connection(request.sender_id)?;
1419        let responder_id = UserId::from_proto(request.payload.responder_id);
1420        if requester_id == responder_id {
1421            return Err(anyhow!("cannot add yourself as a contact"))?;
1422        }
1423
1424        self.app_state
1425            .db
1426            .send_contact_request(requester_id, responder_id)
1427            .await?;
1428
1429        // Update outgoing contact requests of requester
1430        let mut update = proto::UpdateContacts::default();
1431        update.outgoing_requests.push(responder_id.to_proto());
1432        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1433            self.peer.send(connection_id, update.clone())?;
1434        }
1435
1436        // Update incoming contact requests of responder
1437        let mut update = proto::UpdateContacts::default();
1438        update
1439            .incoming_requests
1440            .push(proto::IncomingContactRequest {
1441                requester_id: requester_id.to_proto(),
1442                should_notify: true,
1443            });
1444        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1445            self.peer.send(connection_id, update.clone())?;
1446        }
1447
1448        response.send(proto::Ack {})?;
1449        Ok(())
1450    }
1451
1452    async fn respond_to_contact_request(
1453        self: Arc<Server>,
1454        request: TypedEnvelope<proto::RespondToContactRequest>,
1455        response: Response<proto::RespondToContactRequest>,
1456    ) -> Result<()> {
1457        let responder_id = self
1458            .store()
1459            .await
1460            .user_id_for_connection(request.sender_id)?;
1461        let requester_id = UserId::from_proto(request.payload.requester_id);
1462        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1463            self.app_state
1464                .db
1465                .dismiss_contact_notification(responder_id, requester_id)
1466                .await?;
1467        } else {
1468            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1469            self.app_state
1470                .db
1471                .respond_to_contact_request(responder_id, requester_id, accept)
1472                .await?;
1473
1474            let store = self.store().await;
1475            // Update responder with new contact
1476            let mut update = proto::UpdateContacts::default();
1477            if accept {
1478                update
1479                    .contacts
1480                    .push(store.contact_for_user(requester_id, false));
1481            }
1482            update
1483                .remove_incoming_requests
1484                .push(requester_id.to_proto());
1485            for connection_id in store.connection_ids_for_user(responder_id) {
1486                self.peer.send(connection_id, update.clone())?;
1487            }
1488
1489            // Update requester with new contact
1490            let mut update = proto::UpdateContacts::default();
1491            if accept {
1492                update
1493                    .contacts
1494                    .push(store.contact_for_user(responder_id, true));
1495            }
1496            update
1497                .remove_outgoing_requests
1498                .push(responder_id.to_proto());
1499            for connection_id in store.connection_ids_for_user(requester_id) {
1500                self.peer.send(connection_id, update.clone())?;
1501            }
1502        }
1503
1504        response.send(proto::Ack {})?;
1505        Ok(())
1506    }
1507
1508    async fn remove_contact(
1509        self: Arc<Server>,
1510        request: TypedEnvelope<proto::RemoveContact>,
1511        response: Response<proto::RemoveContact>,
1512    ) -> Result<()> {
1513        let requester_id = self
1514            .store()
1515            .await
1516            .user_id_for_connection(request.sender_id)?;
1517        let responder_id = UserId::from_proto(request.payload.user_id);
1518        self.app_state
1519            .db
1520            .remove_contact(requester_id, responder_id)
1521            .await?;
1522
1523        // Update outgoing contact requests of requester
1524        let mut update = proto::UpdateContacts::default();
1525        update
1526            .remove_outgoing_requests
1527            .push(responder_id.to_proto());
1528        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1529            self.peer.send(connection_id, update.clone())?;
1530        }
1531
1532        // Update incoming contact requests of responder
1533        let mut update = proto::UpdateContacts::default();
1534        update
1535            .remove_incoming_requests
1536            .push(requester_id.to_proto());
1537        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1538            self.peer.send(connection_id, update.clone())?;
1539        }
1540
1541        response.send(proto::Ack {})?;
1542        Ok(())
1543    }
1544
1545    async fn join_channel(
1546        self: Arc<Self>,
1547        request: TypedEnvelope<proto::JoinChannel>,
1548        response: Response<proto::JoinChannel>,
1549    ) -> Result<()> {
1550        let user_id = self
1551            .store()
1552            .await
1553            .user_id_for_connection(request.sender_id)?;
1554        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1555        if !self
1556            .app_state
1557            .db
1558            .can_user_access_channel(user_id, channel_id)
1559            .await?
1560        {
1561            Err(anyhow!("access denied"))?;
1562        }
1563
1564        self.store()
1565            .await
1566            .join_channel(request.sender_id, channel_id);
1567        let messages = self
1568            .app_state
1569            .db
1570            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1571            .await?
1572            .into_iter()
1573            .map(|msg| proto::ChannelMessage {
1574                id: msg.id.to_proto(),
1575                body: msg.body,
1576                timestamp: msg.sent_at.unix_timestamp() as u64,
1577                sender_id: msg.sender_id.to_proto(),
1578                nonce: Some(msg.nonce.as_u128().into()),
1579            })
1580            .collect::<Vec<_>>();
1581        response.send(proto::JoinChannelResponse {
1582            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1583            messages,
1584        })?;
1585        Ok(())
1586    }
1587
1588    async fn leave_channel(
1589        self: Arc<Self>,
1590        request: TypedEnvelope<proto::LeaveChannel>,
1591    ) -> Result<()> {
1592        let user_id = self
1593            .store()
1594            .await
1595            .user_id_for_connection(request.sender_id)?;
1596        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1597        if !self
1598            .app_state
1599            .db
1600            .can_user_access_channel(user_id, channel_id)
1601            .await?
1602        {
1603            Err(anyhow!("access denied"))?;
1604        }
1605
1606        self.store()
1607            .await
1608            .leave_channel(request.sender_id, channel_id);
1609
1610        Ok(())
1611    }
1612
1613    async fn send_channel_message(
1614        self: Arc<Self>,
1615        request: TypedEnvelope<proto::SendChannelMessage>,
1616        response: Response<proto::SendChannelMessage>,
1617    ) -> Result<()> {
1618        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1619        let user_id;
1620        let connection_ids;
1621        {
1622            let state = self.store().await;
1623            user_id = state.user_id_for_connection(request.sender_id)?;
1624            connection_ids = state.channel_connection_ids(channel_id)?;
1625        }
1626
1627        // Validate the message body.
1628        let body = request.payload.body.trim().to_string();
1629        if body.len() > MAX_MESSAGE_LEN {
1630            return Err(anyhow!("message is too long"))?;
1631        }
1632        if body.is_empty() {
1633            return Err(anyhow!("message can't be blank"))?;
1634        }
1635
1636        let timestamp = OffsetDateTime::now_utc();
1637        let nonce = request
1638            .payload
1639            .nonce
1640            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1641
1642        let message_id = self
1643            .app_state
1644            .db
1645            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1646            .await?
1647            .to_proto();
1648        let message = proto::ChannelMessage {
1649            sender_id: user_id.to_proto(),
1650            id: message_id,
1651            body,
1652            timestamp: timestamp.unix_timestamp() as u64,
1653            nonce: Some(nonce),
1654        };
1655        broadcast(request.sender_id, connection_ids, |conn_id| {
1656            self.peer.send(
1657                conn_id,
1658                proto::ChannelMessageSent {
1659                    channel_id: channel_id.to_proto(),
1660                    message: Some(message.clone()),
1661                },
1662            )
1663        });
1664        response.send(proto::SendChannelMessageResponse {
1665            message: Some(message),
1666        })?;
1667        Ok(())
1668    }
1669
1670    async fn get_channel_messages(
1671        self: Arc<Self>,
1672        request: TypedEnvelope<proto::GetChannelMessages>,
1673        response: Response<proto::GetChannelMessages>,
1674    ) -> Result<()> {
1675        let user_id = self
1676            .store()
1677            .await
1678            .user_id_for_connection(request.sender_id)?;
1679        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1680        if !self
1681            .app_state
1682            .db
1683            .can_user_access_channel(user_id, channel_id)
1684            .await?
1685        {
1686            Err(anyhow!("access denied"))?;
1687        }
1688
1689        let messages = self
1690            .app_state
1691            .db
1692            .get_channel_messages(
1693                channel_id,
1694                MESSAGE_COUNT_PER_PAGE,
1695                Some(MessageId::from_proto(request.payload.before_message_id)),
1696            )
1697            .await?
1698            .into_iter()
1699            .map(|msg| proto::ChannelMessage {
1700                id: msg.id.to_proto(),
1701                body: msg.body,
1702                timestamp: msg.sent_at.unix_timestamp() as u64,
1703                sender_id: msg.sender_id.to_proto(),
1704                nonce: Some(msg.nonce.as_u128().into()),
1705            })
1706            .collect::<Vec<_>>();
1707        response.send(proto::GetChannelMessagesResponse {
1708            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1709            messages,
1710        })?;
1711        Ok(())
1712    }
1713
1714    pub(crate) async fn store<'a>(&'a self) -> StoreGuard<'a> {
1715        #[cfg(test)]
1716        tokio::task::yield_now().await;
1717        let guard = self.store.lock().await;
1718        #[cfg(test)]
1719        tokio::task::yield_now().await;
1720        StoreGuard {
1721            guard,
1722            _not_send: PhantomData,
1723        }
1724    }
1725
1726    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1727        ServerSnapshot {
1728            store: self.store().await,
1729            peer: &self.peer,
1730        }
1731    }
1732}
1733
1734impl<'a> Deref for StoreGuard<'a> {
1735    type Target = Store;
1736
1737    fn deref(&self) -> &Self::Target {
1738        &*self.guard
1739    }
1740}
1741
1742impl<'a> DerefMut for StoreGuard<'a> {
1743    fn deref_mut(&mut self) -> &mut Self::Target {
1744        &mut *self.guard
1745    }
1746}
1747
1748impl<'a> Drop for StoreGuard<'a> {
1749    fn drop(&mut self) {
1750        #[cfg(test)]
1751        self.check_invariants();
1752    }
1753}
1754
1755impl Executor for RealExecutor {
1756    type Sleep = Sleep;
1757
1758    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1759        tokio::task::spawn(future);
1760    }
1761
1762    fn sleep(&self, duration: Duration) -> Self::Sleep {
1763        tokio::time::sleep(duration)
1764    }
1765}
1766
1767fn broadcast<F>(
1768    sender_id: ConnectionId,
1769    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1770    mut f: F,
1771) where
1772    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1773{
1774    for receiver_id in receiver_ids {
1775        if receiver_id != sender_id {
1776            f(receiver_id).trace_err();
1777        }
1778    }
1779}
1780
1781lazy_static! {
1782    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1783}
1784
1785pub struct ProtocolVersion(u32);
1786
1787impl Header for ProtocolVersion {
1788    fn name() -> &'static HeaderName {
1789        &ZED_PROTOCOL_VERSION
1790    }
1791
1792    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1793    where
1794        Self: Sized,
1795        I: Iterator<Item = &'i axum::http::HeaderValue>,
1796    {
1797        let version = values
1798            .next()
1799            .ok_or_else(|| axum::headers::Error::invalid())?
1800            .to_str()
1801            .map_err(|_| axum::headers::Error::invalid())?
1802            .parse()
1803            .map_err(|_| axum::headers::Error::invalid())?;
1804        Ok(Self(version))
1805    }
1806
1807    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1808        values.extend([self.0.to_string().parse().unwrap()]);
1809    }
1810}
1811
1812pub fn routes(server: Arc<Server>) -> Router<Body> {
1813    Router::new()
1814        .route("/rpc", get(handle_websocket_request))
1815        .layer(
1816            ServiceBuilder::new()
1817                .layer(Extension(server.app_state.clone()))
1818                .layer(middleware::from_fn(auth::validate_header)),
1819        )
1820        .route("/metrics", get(handle_metrics))
1821        .layer(Extension(server))
1822}
1823
1824pub async fn handle_websocket_request(
1825    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1826    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1827    Extension(server): Extension<Arc<Server>>,
1828    Extension(user): Extension<User>,
1829    ws: WebSocketUpgrade,
1830) -> axum::response::Response {
1831    if protocol_version != rpc::PROTOCOL_VERSION {
1832        return (
1833            StatusCode::UPGRADE_REQUIRED,
1834            "client must be upgraded".to_string(),
1835        )
1836            .into_response();
1837    }
1838    let socket_address = socket_address.to_string();
1839    ws.on_upgrade(move |socket| {
1840        use util::ResultExt;
1841        let socket = socket
1842            .map_ok(to_tungstenite_message)
1843            .err_into()
1844            .with(|message| async move { Ok(to_axum_message(message)) });
1845        let connection = Connection::new(Box::pin(socket));
1846        async move {
1847            server
1848                .handle_connection(connection, socket_address, user, None, RealExecutor)
1849                .await
1850                .log_err();
1851        }
1852    })
1853}
1854
1855pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1856    // We call `store_mut` here for its side effects of updating metrics.
1857    let metrics = server.store().await.metrics();
1858    METRIC_CONNECTIONS.set(metrics.connections as _);
1859    METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1860    METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1861    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1862
1863    let encoder = prometheus::TextEncoder::new();
1864    let metric_families = prometheus::gather();
1865    match encoder.encode_to_string(&metric_families) {
1866        Ok(string) => (StatusCode::OK, string).into_response(),
1867        Err(error) => (
1868            StatusCode::INTERNAL_SERVER_ERROR,
1869            format!("failed to encode metrics {:?}", error),
1870        )
1871            .into_response(),
1872    }
1873}
1874
1875fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1876    match message {
1877        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1878        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1879        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1880        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1881        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1882            code: frame.code.into(),
1883            reason: frame.reason,
1884        })),
1885    }
1886}
1887
1888fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1889    match message {
1890        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1891        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1892        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1893        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1894        AxumMessage::Close(frame) => {
1895            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1896                code: frame.code.into(),
1897                reason: frame.reason,
1898            }))
1899        }
1900    }
1901}
1902
/// Extension for turning a fallible value into an `Option` while recording the failure.
pub trait ResultExt {
    /// Returns the success value, or `None` after logging the error.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The value yielded on success.
    type Ok;
}
1908
1909impl<T, E> ResultExt for Result<T, E>
1910where
1911    E: std::fmt::Debug,
1912{
1913    type Ok = T;
1914
1915    fn trace_err(self) -> Option<T> {
1916        match self {
1917            Ok(value) => Some(value),
1918            Err(error) => {
1919                tracing::error!("{:?}", error);
1920                None
1921            }
1922        }
1923    }
1924}