rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, ProjectId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52use time::OffsetDateTime;
  53use tokio::{
  54    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub use store::{Store, Worktree};
  61
  62lazy_static! {
  63    static ref METRIC_CONNECTIONS: IntGauge =
  64        register_int_gauge!("connections", "number of connections").unwrap();
  65    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
  66        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
  67    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
  68        register_int_gauge!("active_projects", "number of active projects").unwrap();
  69    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  70        "shared_projects",
  71        "number of open projects with one or more guests"
  72    )
  73    .unwrap();
  74}
  75
  76type MessageHandler =
  77    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  78
  79struct Response<R> {
  80    server: Arc<Server>,
  81    receipt: Receipt<R>,
  82    responded: Arc<AtomicBool>,
  83}
  84
  85impl<R: RequestMessage> Response<R> {
  86    fn send(self, payload: R::Response) -> Result<()> {
  87        self.responded.store(true, SeqCst);
  88        self.server.peer.respond(self.receipt, payload)?;
  89        Ok(())
  90    }
  91
  92    fn into_receipt(self) -> Receipt<R> {
  93        self.responded.store(true, SeqCst);
  94        self.receipt
  95    }
  96}
  97
  98pub struct Server {
  99    peer: Arc<Peer>,
 100    pub(crate) store: RwLock<Store>,
 101    app_state: Arc<AppState>,
 102    handlers: HashMap<TypeId, MessageHandler>,
 103    notifications: Option<mpsc::UnboundedSender<()>>,
 104}
 105
 106pub trait Executor: Send + Clone {
 107    type Sleep: Send + Future;
 108    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 109    fn sleep(&self, duration: Duration) -> Self::Sleep;
 110}
 111
 112#[derive(Clone)]
 113pub struct RealExecutor;
 114
 115const MESSAGE_COUNT_PER_PAGE: usize = 100;
 116const MAX_MESSAGE_LEN: usize = 1024;
 117
 118struct StoreReadGuard<'a> {
 119    guard: RwLockReadGuard<'a, Store>,
 120    _not_send: PhantomData<Rc<()>>,
 121}
 122
 123struct StoreWriteGuard<'a> {
 124    guard: RwLockWriteGuard<'a, Store>,
 125    _not_send: PhantomData<Rc<()>>,
 126}
 127
 128#[derive(Serialize)]
 129pub struct ServerSnapshot<'a> {
 130    peer: &'a Peer,
 131    #[serde(serialize_with = "serialize_deref")]
 132    store: RwLockReadGuard<'a, Store>,
 133}
 134
 135pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 136where
 137    S: Serializer,
 138    T: Deref<Target = U>,
 139    U: Serialize,
 140{
 141    Serialize::serialize(value.deref(), serializer)
 142}
 143
 144impl Server {
 145    pub fn new(
 146        app_state: Arc<AppState>,
 147        notifications: Option<mpsc::UnboundedSender<()>>,
 148    ) -> Arc<Self> {
 149        let mut server = Self {
 150            peer: Peer::new(),
 151            app_state,
 152            store: Default::default(),
 153            handlers: Default::default(),
 154            notifications,
 155        };
 156
 157        server
 158            .add_request_handler(Server::ping)
 159            .add_request_handler(Server::register_project)
 160            .add_request_handler(Server::unregister_project)
 161            .add_request_handler(Server::join_project)
 162            .add_message_handler(Server::leave_project)
 163            .add_message_handler(Server::respond_to_join_project_request)
 164            .add_message_handler(Server::update_project)
 165            .add_message_handler(Server::register_project_activity)
 166            .add_request_handler(Server::update_worktree)
 167            .add_message_handler(Server::start_language_server)
 168            .add_message_handler(Server::update_language_server)
 169            .add_message_handler(Server::update_diagnostic_summary)
 170            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 172            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 173            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 174            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 175            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 176            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 177            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 178            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 179            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 180            .add_request_handler(
 181                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 182            )
 183            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 184            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 185            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 186            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 187            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 188            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 189            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 190            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 191            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 192            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 193            .add_request_handler(Server::update_buffer)
 194            .add_message_handler(Server::update_buffer_file)
 195            .add_message_handler(Server::buffer_reloaded)
 196            .add_message_handler(Server::buffer_saved)
 197            .add_request_handler(Server::save_buffer)
 198            .add_request_handler(Server::get_channels)
 199            .add_request_handler(Server::get_users)
 200            .add_request_handler(Server::fuzzy_search_users)
 201            .add_request_handler(Server::request_contact)
 202            .add_request_handler(Server::remove_contact)
 203            .add_request_handler(Server::respond_to_contact_request)
 204            .add_request_handler(Server::join_channel)
 205            .add_message_handler(Server::leave_channel)
 206            .add_request_handler(Server::send_channel_message)
 207            .add_request_handler(Server::follow)
 208            .add_message_handler(Server::unfollow)
 209            .add_message_handler(Server::update_followers)
 210            .add_request_handler(Server::get_channel_messages);
 211
 212        Arc::new(server)
 213    }
 214
 215    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 216    where
 217        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 218        Fut: 'static + Send + Future<Output = Result<()>>,
 219        M: EnvelopedMessage,
 220    {
 221        let prev_handler = self.handlers.insert(
 222            TypeId::of::<M>(),
 223            Box::new(move |server, envelope| {
 224                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 225                let span = info_span!(
 226                    "handle message",
 227                    payload_type = envelope.payload_type_name()
 228                );
 229                span.in_scope(|| {
 230                    tracing::info!(
 231                        payload_type = envelope.payload_type_name(),
 232                        "message received"
 233                    );
 234                });
 235                let future = (handler)(server, *envelope);
 236                async move {
 237                    if let Err(error) = future.await {
 238                        tracing::error!(%error, "error handling message");
 239                    }
 240                }
 241                .instrument(span)
 242                .boxed()
 243            }),
 244        );
 245        if prev_handler.is_some() {
 246            panic!("registered a handler for the same message twice");
 247        }
 248        self
 249    }
 250
 251    /// Handle a request while holding a lock to the store. This is useful when we're registering
 252    /// a connection but we want to respond on the connection before anybody else can send on it.
 253    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 254    where
 255        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 256        Fut: Send + Future<Output = Result<()>>,
 257        M: RequestMessage,
 258    {
 259        let handler = Arc::new(handler);
 260        self.add_message_handler(move |server, envelope| {
 261            let receipt = envelope.receipt();
 262            let handler = handler.clone();
 263            async move {
 264                let responded = Arc::new(AtomicBool::default());
 265                let response = Response {
 266                    server: server.clone(),
 267                    responded: responded.clone(),
 268                    receipt: envelope.receipt(),
 269                };
 270                match (handler)(server.clone(), envelope, response).await {
 271                    Ok(()) => {
 272                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 273                            Ok(())
 274                        } else {
 275                            Err(anyhow!("handler did not send a response"))?
 276                        }
 277                    }
 278                    Err(error) => {
 279                        server.peer.respond_with_error(
 280                            receipt,
 281                            proto::Error {
 282                                message: error.to_string(),
 283                            },
 284                        )?;
 285                        Err(error)
 286                    }
 287                }
 288            }
 289        })
 290    }
 291
 292    /// Start a long lived task that records which users are active in which projects.
 293    pub fn start_recording_project_activity<E: 'static + Executor>(
 294        self: &Arc<Self>,
 295        interval: Duration,
 296        executor: E,
 297    ) {
 298        executor.spawn_detached({
 299            let this = Arc::downgrade(self);
 300            let executor = executor.clone();
 301            async move {
 302                let mut period_start = OffsetDateTime::now_utc();
 303                let mut active_projects = Vec::<(UserId, ProjectId)>::new();
 304                loop {
 305                    let sleep = executor.sleep(interval);
 306                    sleep.await;
 307                    let this = if let Some(this) = this.upgrade() {
 308                        this
 309                    } else {
 310                        break;
 311                    };
 312
 313                    active_projects.clear();
 314                    active_projects.extend(this.store().await.projects().flat_map(
 315                        |(project_id, project)| {
 316                            project.guests.values().chain([&project.host]).filter_map(
 317                                |collaborator| {
 318                                    if !collaborator.admin
 319                                        && collaborator
 320                                            .last_activity
 321                                            .map_or(false, |activity| activity > period_start)
 322                                    {
 323                                        Some((collaborator.user_id, *project_id))
 324                                    } else {
 325                                        None
 326                                    }
 327                                },
 328                            )
 329                        },
 330                    ));
 331
 332                    let period_end = OffsetDateTime::now_utc();
 333                    this.app_state
 334                        .db
 335                        .record_user_activity(period_start..period_end, &active_projects)
 336                        .await
 337                        .trace_err();
 338                    period_start = period_end;
 339                }
 340            }
 341        });
 342    }
 343
 344    pub fn handle_connection<E: Executor>(
 345        self: &Arc<Self>,
 346        connection: Connection,
 347        address: String,
 348        user: User,
 349        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 350        executor: E,
 351    ) -> impl Future<Output = Result<()>> {
 352        let mut this = self.clone();
 353        let user_id = user.id;
 354        let login = user.github_login;
 355        let span = info_span!("handle connection", %user_id, %login, %address);
 356        async move {
 357            let (connection_id, handle_io, mut incoming_rx) = this
 358                .peer
 359                .add_connection(connection, {
 360                    let executor = executor.clone();
 361                    move |duration| {
 362                        let timer = executor.sleep(duration);
 363                        async move {
 364                            timer.await;
 365                        }
 366                    }
 367                })
 368                .await;
 369
 370            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 371
 372            if let Some(send_connection_id) = send_connection_id.as_mut() {
 373                let _ = send_connection_id.send(connection_id).await;
 374            }
 375
 376            if !user.connected_once {
 377                this.peer.send(connection_id, proto::ShowContacts {})?;
 378                this.app_state.db.set_user_connected_once(user_id, true).await?;
 379            }
 380
 381            let (contacts, invite_code) = future::try_join(
 382                this.app_state.db.get_contacts(user_id),
 383                this.app_state.db.get_invite_code_for_user(user_id)
 384            ).await?;
 385
 386            {
 387                let mut store = this.store_mut().await;
 388                store.add_connection(connection_id, user_id, user.admin);
 389                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 390
 391                if let Some((code, count)) = invite_code {
 392                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 393                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 394                        count,
 395                    })?;
 396                }
 397            }
 398            this.update_user_contacts(user_id).await?;
 399
 400            let handle_io = handle_io.fuse();
 401            futures::pin_mut!(handle_io);
 402
 403            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 404            // This prevents deadlocks when e.g., client A performs a request to client B and
 405            // client B performs a request to client A. If both clients stop processing further
 406            // messages until their respective request completes, they won't have a chance to
 407            // respond to the other client's request and cause a deadlock.
 408            //
 409            // This arrangement ensures we will attempt to process earlier messages first, but fall
 410            // back to processing messages arrived later in the spirit of making progress.
 411            let mut foreground_message_handlers = FuturesUnordered::new();
 412            loop {
 413                let next_message = incoming_rx.next().fuse();
 414                futures::pin_mut!(next_message);
 415                futures::select_biased! {
 416                    result = handle_io => {
 417                        if let Err(error) = result {
 418                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 419                        }
 420                        break;
 421                    }
 422                    _ = foreground_message_handlers.next() => {}
 423                    message = next_message => {
 424                        if let Some(message) = message {
 425                            let type_name = message.payload_type_name();
 426                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 427                            let span_enter = span.enter();
 428                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 429                                let notifications = this.notifications.clone();
 430                                let is_background = message.is_background();
 431                                let handle_message = (handler)(this.clone(), message);
 432
 433                                drop(span_enter);
 434                                let handle_message = async move {
 435                                    handle_message.await;
 436                                    if let Some(mut notifications) = notifications {
 437                                        let _ = notifications.send(()).await;
 438                                    }
 439                                }.instrument(span);
 440
 441                                if is_background {
 442                                    executor.spawn_detached(handle_message);
 443                                } else {
 444                                    foreground_message_handlers.push(handle_message);
 445                                }
 446                            } else {
 447                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 448                            }
 449                        } else {
 450                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 451                            break;
 452                        }
 453                    }
 454                }
 455            }
 456
 457            drop(foreground_message_handlers);
 458            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 459            if let Err(error) = this.sign_out(connection_id).await {
 460                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 461            }
 462
 463            Ok(())
 464        }.instrument(span)
 465    }
 466
 467    #[instrument(skip(self), err)]
 468    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 469        self.peer.disconnect(connection_id);
 470
 471        let mut projects_to_unregister = Vec::new();
 472        let removed_user_id;
 473        {
 474            let mut store = self.store_mut().await;
 475            let removed_connection = store.remove_connection(connection_id)?;
 476
 477            for (project_id, project) in removed_connection.hosted_projects {
 478                projects_to_unregister.push(project_id);
 479                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 480                    self.peer.send(
 481                        conn_id,
 482                        proto::UnregisterProject {
 483                            project_id: project_id.to_proto(),
 484                        },
 485                    )
 486                });
 487
 488                for (_, receipts) in project.join_requests {
 489                    for receipt in receipts {
 490                        self.peer.respond(
 491                            receipt,
 492                            proto::JoinProjectResponse {
 493                                variant: Some(proto::join_project_response::Variant::Decline(
 494                                    proto::join_project_response::Decline {
 495                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 496                                    },
 497                                )),
 498                            },
 499                        )?;
 500                    }
 501                }
 502            }
 503
 504            for project_id in removed_connection.guest_project_ids {
 505                if let Some(project) = store.project(project_id).trace_err() {
 506                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 507                        self.peer.send(
 508                            conn_id,
 509                            proto::RemoveProjectCollaborator {
 510                                project_id: project_id.to_proto(),
 511                                peer_id: connection_id.0,
 512                            },
 513                        )
 514                    });
 515                    if project.guests.is_empty() {
 516                        self.peer
 517                            .send(
 518                                project.host_connection_id,
 519                                proto::ProjectUnshared {
 520                                    project_id: project_id.to_proto(),
 521                                },
 522                            )
 523                            .trace_err();
 524                    }
 525                }
 526            }
 527
 528            removed_user_id = removed_connection.user_id;
 529        };
 530
 531        self.update_user_contacts(removed_user_id).await.trace_err();
 532
 533        for project_id in projects_to_unregister {
 534            self.app_state
 535                .db
 536                .unregister_project(project_id)
 537                .await
 538                .trace_err();
 539        }
 540
 541        Ok(())
 542    }
 543
 544    pub async fn invite_code_redeemed(
 545        self: &Arc<Self>,
 546        code: &str,
 547        invitee_id: UserId,
 548    ) -> Result<()> {
 549        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 550        let store = self.store().await;
 551        let invitee_contact = store.contact_for_user(invitee_id, true);
 552        for connection_id in store.connection_ids_for_user(user.id) {
 553            self.peer.send(
 554                connection_id,
 555                proto::UpdateContacts {
 556                    contacts: vec![invitee_contact.clone()],
 557                    ..Default::default()
 558                },
 559            )?;
 560            self.peer.send(
 561                connection_id,
 562                proto::UpdateInviteInfo {
 563                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 564                    count: user.invite_count as u32,
 565                },
 566            )?;
 567        }
 568        Ok(())
 569    }
 570
 571    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 572        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 573            if let Some(invite_code) = &user.invite_code {
 574                let store = self.store().await;
 575                for connection_id in store.connection_ids_for_user(user_id) {
 576                    self.peer.send(
 577                        connection_id,
 578                        proto::UpdateInviteInfo {
 579                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 580                            count: user.invite_count as u32,
 581                        },
 582                    )?;
 583                }
 584            }
 585        }
 586        Ok(())
 587    }
 588
 589    async fn ping(
 590        self: Arc<Server>,
 591        _: TypedEnvelope<proto::Ping>,
 592        response: Response<proto::Ping>,
 593    ) -> Result<()> {
 594        response.send(proto::Ack {})?;
 595        Ok(())
 596    }
 597
 598    async fn register_project(
 599        self: Arc<Server>,
 600        request: TypedEnvelope<proto::RegisterProject>,
 601        response: Response<proto::RegisterProject>,
 602    ) -> Result<()> {
 603        let user_id = self
 604            .store()
 605            .await
 606            .user_id_for_connection(request.sender_id)?;
 607        let project_id = self.app_state.db.register_project(user_id).await?;
 608        self.store_mut().await.register_project(
 609            request.sender_id,
 610            project_id,
 611            request.payload.online,
 612        )?;
 613
 614        response.send(proto::RegisterProjectResponse {
 615            project_id: project_id.to_proto(),
 616        })?;
 617
 618        Ok(())
 619    }
 620
 621    async fn unregister_project(
 622        self: Arc<Server>,
 623        request: TypedEnvelope<proto::UnregisterProject>,
 624        response: Response<proto::UnregisterProject>,
 625    ) -> Result<()> {
 626        let project_id = ProjectId::from_proto(request.payload.project_id);
 627        let (user_id, project) = {
 628            let mut state = self.store_mut().await;
 629            let project = state.unregister_project(project_id, request.sender_id)?;
 630            (state.user_id_for_connection(request.sender_id)?, project)
 631        };
 632        self.app_state.db.unregister_project(project_id).await?;
 633
 634        broadcast(
 635            request.sender_id,
 636            project.guests.keys().copied(),
 637            |conn_id| {
 638                self.peer.send(
 639                    conn_id,
 640                    proto::UnregisterProject {
 641                        project_id: project_id.to_proto(),
 642                    },
 643                )
 644            },
 645        );
 646        for (_, receipts) in project.join_requests {
 647            for receipt in receipts {
 648                self.peer.respond(
 649                    receipt,
 650                    proto::JoinProjectResponse {
 651                        variant: Some(proto::join_project_response::Variant::Decline(
 652                            proto::join_project_response::Decline {
 653                                reason: proto::join_project_response::decline::Reason::Closed
 654                                    as i32,
 655                            },
 656                        )),
 657                    },
 658                )?;
 659            }
 660        }
 661
 662        // Send out the `UpdateContacts` message before responding to the unregister
 663        // request. This way, when the project's host can keep track of the project's
 664        // remote id until after they've received the `UpdateContacts` message for
 665        // themself.
 666        self.update_user_contacts(user_id).await?;
 667        response.send(proto::Ack {})?;
 668
 669        Ok(())
 670    }
 671
 672    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 673        let contacts = self.app_state.db.get_contacts(user_id).await?;
 674        let store = self.store().await;
 675        let updated_contact = store.contact_for_user(user_id, false);
 676        for contact in contacts {
 677            if let db::Contact::Accepted {
 678                user_id: contact_user_id,
 679                ..
 680            } = contact
 681            {
 682                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 683                    self.peer
 684                        .send(
 685                            contact_conn_id,
 686                            proto::UpdateContacts {
 687                                contacts: vec![updated_contact.clone()],
 688                                remove_contacts: Default::default(),
 689                                incoming_requests: Default::default(),
 690                                remove_incoming_requests: Default::default(),
 691                                outgoing_requests: Default::default(),
 692                                remove_outgoing_requests: Default::default(),
 693                            },
 694                        )
 695                        .trace_err();
 696                }
 697            }
 698        }
 699        Ok(())
 700    }
 701
 702    async fn join_project(
 703        self: Arc<Server>,
 704        request: TypedEnvelope<proto::JoinProject>,
 705        response: Response<proto::JoinProject>,
 706    ) -> Result<()> {
 707        let project_id = ProjectId::from_proto(request.payload.project_id);
 708
 709        let host_user_id;
 710        let guest_user_id;
 711        let host_connection_id;
 712        {
 713            let state = self.store().await;
 714            let project = state.project(project_id)?;
 715            host_user_id = project.host.user_id;
 716            host_connection_id = project.host_connection_id;
 717            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 718        };
 719
 720        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 721        let has_contact = self
 722            .app_state
 723            .db
 724            .has_contact(guest_user_id, host_user_id)
 725            .await?;
 726        if !has_contact {
 727            return Err(anyhow!("no such project"))?;
 728        }
 729
 730        self.store_mut().await.request_join_project(
 731            guest_user_id,
 732            project_id,
 733            response.into_receipt(),
 734        )?;
 735        self.peer.send(
 736            host_connection_id,
 737            proto::RequestJoinProject {
 738                project_id: project_id.to_proto(),
 739                requester_id: guest_user_id.to_proto(),
 740            },
 741        )?;
 742        Ok(())
 743    }
 744
 745    async fn respond_to_join_project_request(
 746        self: Arc<Server>,
 747        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 748    ) -> Result<()> {
 749        let host_user_id;
 750
 751        {
 752            let mut state = self.store_mut().await;
 753            let project_id = ProjectId::from_proto(request.payload.project_id);
 754            let project = state.project(project_id)?;
 755            if project.host_connection_id != request.sender_id {
 756                Err(anyhow!("no such connection"))?;
 757            }
 758
 759            host_user_id = project.host.user_id;
 760            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 761
 762            if !request.payload.allow {
 763                let receipts = state
 764                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 765                    .ok_or_else(|| anyhow!("no such request"))?;
 766                for receipt in receipts {
 767                    self.peer.respond(
 768                        receipt,
 769                        proto::JoinProjectResponse {
 770                            variant: Some(proto::join_project_response::Variant::Decline(
 771                                proto::join_project_response::Decline {
 772                                    reason: proto::join_project_response::decline::Reason::Declined
 773                                        as i32,
 774                                },
 775                            )),
 776                        },
 777                    )?;
 778                }
 779                return Ok(());
 780            }
 781
 782            let (receipts_with_replica_ids, project) = state
 783                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 784                .ok_or_else(|| anyhow!("no such request"))?;
 785
 786            let peer_count = project.guests.len();
 787            let mut collaborators = Vec::with_capacity(peer_count);
 788            collaborators.push(proto::Collaborator {
 789                peer_id: project.host_connection_id.0,
 790                replica_id: 0,
 791                user_id: project.host.user_id.to_proto(),
 792            });
 793            let worktrees = project
 794                .worktrees
 795                .iter()
 796                .filter_map(|(id, shared_worktree)| {
 797                    let worktree = project.worktrees.get(&id)?;
 798                    Some(proto::Worktree {
 799                        id: *id,
 800                        root_name: worktree.root_name.clone(),
 801                        entries: shared_worktree.entries.values().cloned().collect(),
 802                        diagnostic_summaries: shared_worktree
 803                            .diagnostic_summaries
 804                            .values()
 805                            .cloned()
 806                            .collect(),
 807                        visible: worktree.visible,
 808                        scan_id: shared_worktree.scan_id,
 809                    })
 810                })
 811                .collect::<Vec<_>>();
 812
 813            // Add all guests other than the requesting user's own connections as collaborators
 814            for (guest_conn_id, guest) in &project.guests {
 815                if receipts_with_replica_ids
 816                    .iter()
 817                    .all(|(receipt, _)| receipt.sender_id != *guest_conn_id)
 818                {
 819                    collaborators.push(proto::Collaborator {
 820                        peer_id: guest_conn_id.0,
 821                        replica_id: guest.replica_id as u32,
 822                        user_id: guest.user_id.to_proto(),
 823                    });
 824                }
 825            }
 826
 827            for conn_id in project.connection_ids() {
 828                for (receipt, replica_id) in &receipts_with_replica_ids {
 829                    if conn_id != receipt.sender_id {
 830                        self.peer.send(
 831                            conn_id,
 832                            proto::AddProjectCollaborator {
 833                                project_id: project_id.to_proto(),
 834                                collaborator: Some(proto::Collaborator {
 835                                    peer_id: receipt.sender_id.0,
 836                                    replica_id: *replica_id as u32,
 837                                    user_id: guest_user_id.to_proto(),
 838                                }),
 839                            },
 840                        )?;
 841                    }
 842                }
 843            }
 844
 845            for (receipt, replica_id) in receipts_with_replica_ids {
 846                self.peer.respond(
 847                    receipt,
 848                    proto::JoinProjectResponse {
 849                        variant: Some(proto::join_project_response::Variant::Accept(
 850                            proto::join_project_response::Accept {
 851                                worktrees: worktrees.clone(),
 852                                replica_id: replica_id as u32,
 853                                collaborators: collaborators.clone(),
 854                                language_servers: project.language_servers.clone(),
 855                            },
 856                        )),
 857                    },
 858                )?;
 859            }
 860        }
 861
 862        self.update_user_contacts(host_user_id).await?;
 863        Ok(())
 864    }
 865
 866    async fn leave_project(
 867        self: Arc<Server>,
 868        request: TypedEnvelope<proto::LeaveProject>,
 869    ) -> Result<()> {
 870        let sender_id = request.sender_id;
 871        let project_id = ProjectId::from_proto(request.payload.project_id);
 872        let project;
 873        {
 874            let mut store = self.store_mut().await;
 875            project = store.leave_project(sender_id, project_id)?;
 876            tracing::info!(
 877                %project_id,
 878                host_user_id = %project.host_user_id,
 879                host_connection_id = %project.host_connection_id,
 880                "leave project"
 881            );
 882
 883            if project.remove_collaborator {
 884                broadcast(sender_id, project.connection_ids, |conn_id| {
 885                    self.peer.send(
 886                        conn_id,
 887                        proto::RemoveProjectCollaborator {
 888                            project_id: project_id.to_proto(),
 889                            peer_id: sender_id.0,
 890                        },
 891                    )
 892                });
 893            }
 894
 895            if let Some(requester_id) = project.cancel_request {
 896                self.peer.send(
 897                    project.host_connection_id,
 898                    proto::JoinProjectRequestCancelled {
 899                        project_id: project_id.to_proto(),
 900                        requester_id: requester_id.to_proto(),
 901                    },
 902                )?;
 903            }
 904
 905            if project.unshare {
 906                self.peer.send(
 907                    project.host_connection_id,
 908                    proto::ProjectUnshared {
 909                        project_id: project_id.to_proto(),
 910                    },
 911                )?;
 912            }
 913        }
 914        self.update_user_contacts(project.host_user_id).await?;
 915        Ok(())
 916    }
 917
 918    async fn update_project(
 919        self: Arc<Server>,
 920        request: TypedEnvelope<proto::UpdateProject>,
 921    ) -> Result<()> {
 922        let project_id = ProjectId::from_proto(request.payload.project_id);
 923        let user_id;
 924        {
 925            let mut state = self.store_mut().await;
 926            user_id = state.user_id_for_connection(request.sender_id)?;
 927            let guest_connection_ids = state
 928                .read_project(project_id, request.sender_id)?
 929                .guest_connection_ids();
 930            let unshared_project = state.update_project(
 931                project_id,
 932                &request.payload.worktrees,
 933                request.payload.online,
 934                request.sender_id,
 935            )?;
 936
 937            if let Some(unshared_project) = unshared_project {
 938                broadcast(
 939                    request.sender_id,
 940                    unshared_project.guests.keys().copied(),
 941                    |conn_id| {
 942                        self.peer.send(
 943                            conn_id,
 944                            proto::UnregisterProject {
 945                                project_id: project_id.to_proto(),
 946                            },
 947                        )
 948                    },
 949                );
 950                for (_, receipts) in unshared_project.pending_join_requests {
 951                    for receipt in receipts {
 952                        self.peer.respond(
 953                            receipt,
 954                            proto::JoinProjectResponse {
 955                                variant: Some(proto::join_project_response::Variant::Decline(
 956                                    proto::join_project_response::Decline {
 957                                        reason:
 958                                            proto::join_project_response::decline::Reason::Closed
 959                                                as i32,
 960                                    },
 961                                )),
 962                            },
 963                        )?;
 964                    }
 965                }
 966            } else {
 967                broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 968                    self.peer.forward_send(
 969                        request.sender_id,
 970                        connection_id,
 971                        request.payload.clone(),
 972                    )
 973                });
 974            }
 975        };
 976
 977        self.update_user_contacts(user_id).await?;
 978        Ok(())
 979    }
 980
 981    async fn register_project_activity(
 982        self: Arc<Server>,
 983        request: TypedEnvelope<proto::RegisterProjectActivity>,
 984    ) -> Result<()> {
 985        self.store_mut().await.register_project_activity(
 986            ProjectId::from_proto(request.payload.project_id),
 987            request.sender_id,
 988        )?;
 989        Ok(())
 990    }
 991
 992    async fn update_worktree(
 993        self: Arc<Server>,
 994        request: TypedEnvelope<proto::UpdateWorktree>,
 995        response: Response<proto::UpdateWorktree>,
 996    ) -> Result<()> {
 997        let project_id = ProjectId::from_proto(request.payload.project_id);
 998        let worktree_id = request.payload.worktree_id;
 999        let (connection_ids, metadata_changed, extension_counts) = {
1000            let mut store = self.store_mut().await;
1001            let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
1002                request.sender_id,
1003                project_id,
1004                worktree_id,
1005                &request.payload.root_name,
1006                &request.payload.removed_entries,
1007                &request.payload.updated_entries,
1008                request.payload.scan_id,
1009            )?;
1010            (connection_ids, metadata_changed, extension_counts.clone())
1011        };
1012        self.app_state
1013            .db
1014            .update_worktree_extensions(project_id, worktree_id, extension_counts)
1015            .await?;
1016
1017        broadcast(request.sender_id, connection_ids, |connection_id| {
1018            self.peer
1019                .forward_send(request.sender_id, connection_id, request.payload.clone())
1020        });
1021        if metadata_changed {
1022            let user_id = self
1023                .store()
1024                .await
1025                .user_id_for_connection(request.sender_id)?;
1026            self.update_user_contacts(user_id).await?;
1027        }
1028        response.send(proto::Ack {})?;
1029        Ok(())
1030    }
1031
1032    async fn update_diagnostic_summary(
1033        self: Arc<Server>,
1034        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
1035    ) -> Result<()> {
1036        let summary = request
1037            .payload
1038            .summary
1039            .clone()
1040            .ok_or_else(|| anyhow!("invalid summary"))?;
1041        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
1042            ProjectId::from_proto(request.payload.project_id),
1043            request.payload.worktree_id,
1044            request.sender_id,
1045            summary,
1046        )?;
1047
1048        broadcast(request.sender_id, receiver_ids, |connection_id| {
1049            self.peer
1050                .forward_send(request.sender_id, connection_id, request.payload.clone())
1051        });
1052        Ok(())
1053    }
1054
1055    async fn start_language_server(
1056        self: Arc<Server>,
1057        request: TypedEnvelope<proto::StartLanguageServer>,
1058    ) -> Result<()> {
1059        let receiver_ids = self.store_mut().await.start_language_server(
1060            ProjectId::from_proto(request.payload.project_id),
1061            request.sender_id,
1062            request
1063                .payload
1064                .server
1065                .clone()
1066                .ok_or_else(|| anyhow!("invalid language server"))?,
1067        )?;
1068        broadcast(request.sender_id, receiver_ids, |connection_id| {
1069            self.peer
1070                .forward_send(request.sender_id, connection_id, request.payload.clone())
1071        });
1072        Ok(())
1073    }
1074
1075    async fn update_language_server(
1076        self: Arc<Server>,
1077        request: TypedEnvelope<proto::UpdateLanguageServer>,
1078    ) -> Result<()> {
1079        let receiver_ids = self.store().await.project_connection_ids(
1080            ProjectId::from_proto(request.payload.project_id),
1081            request.sender_id,
1082        )?;
1083        broadcast(request.sender_id, receiver_ids, |connection_id| {
1084            self.peer
1085                .forward_send(request.sender_id, connection_id, request.payload.clone())
1086        });
1087        Ok(())
1088    }
1089
1090    async fn forward_project_request<T>(
1091        self: Arc<Server>,
1092        request: TypedEnvelope<T>,
1093        response: Response<T>,
1094    ) -> Result<()>
1095    where
1096        T: EntityMessage + RequestMessage,
1097    {
1098        let host_connection_id = self
1099            .store()
1100            .await
1101            .read_project(
1102                ProjectId::from_proto(request.payload.remote_entity_id()),
1103                request.sender_id,
1104            )?
1105            .host_connection_id;
1106
1107        response.send(
1108            self.peer
1109                .forward_request(request.sender_id, host_connection_id, request.payload)
1110                .await?,
1111        )?;
1112        Ok(())
1113    }
1114
1115    async fn save_buffer(
1116        self: Arc<Server>,
1117        request: TypedEnvelope<proto::SaveBuffer>,
1118        response: Response<proto::SaveBuffer>,
1119    ) -> Result<()> {
1120        let project_id = ProjectId::from_proto(request.payload.project_id);
1121        let host = self
1122            .store()
1123            .await
1124            .read_project(project_id, request.sender_id)?
1125            .host_connection_id;
1126        let response_payload = self
1127            .peer
1128            .forward_request(request.sender_id, host, request.payload.clone())
1129            .await?;
1130
1131        let mut guests = self
1132            .store()
1133            .await
1134            .read_project(project_id, request.sender_id)?
1135            .connection_ids();
1136        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1137        broadcast(host, guests, |conn_id| {
1138            self.peer
1139                .forward_send(host, conn_id, response_payload.clone())
1140        });
1141        response.send(response_payload)?;
1142        Ok(())
1143    }
1144
1145    async fn update_buffer(
1146        self: Arc<Server>,
1147        request: TypedEnvelope<proto::UpdateBuffer>,
1148        response: Response<proto::UpdateBuffer>,
1149    ) -> Result<()> {
1150        let project_id = ProjectId::from_proto(request.payload.project_id);
1151        let receiver_ids = {
1152            let mut store = self.store_mut().await;
1153            store.register_project_activity(project_id, request.sender_id)?;
1154            store.project_connection_ids(project_id, request.sender_id)?
1155        };
1156
1157        broadcast(request.sender_id, receiver_ids, |connection_id| {
1158            self.peer
1159                .forward_send(request.sender_id, connection_id, request.payload.clone())
1160        });
1161        response.send(proto::Ack {})?;
1162        Ok(())
1163    }
1164
1165    async fn update_buffer_file(
1166        self: Arc<Server>,
1167        request: TypedEnvelope<proto::UpdateBufferFile>,
1168    ) -> Result<()> {
1169        let receiver_ids = self.store().await.project_connection_ids(
1170            ProjectId::from_proto(request.payload.project_id),
1171            request.sender_id,
1172        )?;
1173        broadcast(request.sender_id, receiver_ids, |connection_id| {
1174            self.peer
1175                .forward_send(request.sender_id, connection_id, request.payload.clone())
1176        });
1177        Ok(())
1178    }
1179
1180    async fn buffer_reloaded(
1181        self: Arc<Server>,
1182        request: TypedEnvelope<proto::BufferReloaded>,
1183    ) -> Result<()> {
1184        let receiver_ids = self.store().await.project_connection_ids(
1185            ProjectId::from_proto(request.payload.project_id),
1186            request.sender_id,
1187        )?;
1188        broadcast(request.sender_id, receiver_ids, |connection_id| {
1189            self.peer
1190                .forward_send(request.sender_id, connection_id, request.payload.clone())
1191        });
1192        Ok(())
1193    }
1194
1195    async fn buffer_saved(
1196        self: Arc<Server>,
1197        request: TypedEnvelope<proto::BufferSaved>,
1198    ) -> Result<()> {
1199        let receiver_ids = self.store().await.project_connection_ids(
1200            ProjectId::from_proto(request.payload.project_id),
1201            request.sender_id,
1202        )?;
1203        broadcast(request.sender_id, receiver_ids, |connection_id| {
1204            self.peer
1205                .forward_send(request.sender_id, connection_id, request.payload.clone())
1206        });
1207        Ok(())
1208    }
1209
1210    async fn follow(
1211        self: Arc<Self>,
1212        request: TypedEnvelope<proto::Follow>,
1213        response: Response<proto::Follow>,
1214    ) -> Result<()> {
1215        let project_id = ProjectId::from_proto(request.payload.project_id);
1216        let leader_id = ConnectionId(request.payload.leader_id);
1217        let follower_id = request.sender_id;
1218        {
1219            let mut store = self.store_mut().await;
1220            if !store
1221                .project_connection_ids(project_id, follower_id)?
1222                .contains(&leader_id)
1223            {
1224                Err(anyhow!("no such peer"))?;
1225            }
1226
1227            store.register_project_activity(project_id, follower_id)?;
1228        }
1229
1230        let mut response_payload = self
1231            .peer
1232            .forward_request(request.sender_id, leader_id, request.payload)
1233            .await?;
1234        response_payload
1235            .views
1236            .retain(|view| view.leader_id != Some(follower_id.0));
1237        response.send(response_payload)?;
1238        Ok(())
1239    }
1240
1241    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1242        let project_id = ProjectId::from_proto(request.payload.project_id);
1243        let leader_id = ConnectionId(request.payload.leader_id);
1244        let mut store = self.store_mut().await;
1245        if !store
1246            .project_connection_ids(project_id, request.sender_id)?
1247            .contains(&leader_id)
1248        {
1249            Err(anyhow!("no such peer"))?;
1250        }
1251        store.register_project_activity(project_id, request.sender_id)?;
1252        self.peer
1253            .forward_send(request.sender_id, leader_id, request.payload)?;
1254        Ok(())
1255    }
1256
1257    async fn update_followers(
1258        self: Arc<Self>,
1259        request: TypedEnvelope<proto::UpdateFollowers>,
1260    ) -> Result<()> {
1261        let project_id = ProjectId::from_proto(request.payload.project_id);
1262        let mut store = self.store_mut().await;
1263        store.register_project_activity(project_id, request.sender_id)?;
1264        let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
1265        let leader_id = request
1266            .payload
1267            .variant
1268            .as_ref()
1269            .and_then(|variant| match variant {
1270                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1271                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1272                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1273            });
1274        for follower_id in &request.payload.follower_ids {
1275            let follower_id = ConnectionId(*follower_id);
1276            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1277                self.peer
1278                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1279            }
1280        }
1281        Ok(())
1282    }
1283
1284    async fn get_channels(
1285        self: Arc<Server>,
1286        request: TypedEnvelope<proto::GetChannels>,
1287        response: Response<proto::GetChannels>,
1288    ) -> Result<()> {
1289        let user_id = self
1290            .store()
1291            .await
1292            .user_id_for_connection(request.sender_id)?;
1293        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1294        response.send(proto::GetChannelsResponse {
1295            channels: channels
1296                .into_iter()
1297                .map(|chan| proto::Channel {
1298                    id: chan.id.to_proto(),
1299                    name: chan.name,
1300                })
1301                .collect(),
1302        })?;
1303        Ok(())
1304    }
1305
1306    async fn get_users(
1307        self: Arc<Server>,
1308        request: TypedEnvelope<proto::GetUsers>,
1309        response: Response<proto::GetUsers>,
1310    ) -> Result<()> {
1311        let user_ids = request
1312            .payload
1313            .user_ids
1314            .into_iter()
1315            .map(UserId::from_proto)
1316            .collect();
1317        let users = self
1318            .app_state
1319            .db
1320            .get_users_by_ids(user_ids)
1321            .await?
1322            .into_iter()
1323            .map(|user| proto::User {
1324                id: user.id.to_proto(),
1325                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1326                github_login: user.github_login,
1327            })
1328            .collect();
1329        response.send(proto::UsersResponse { users })?;
1330        Ok(())
1331    }
1332
1333    async fn fuzzy_search_users(
1334        self: Arc<Server>,
1335        request: TypedEnvelope<proto::FuzzySearchUsers>,
1336        response: Response<proto::FuzzySearchUsers>,
1337    ) -> Result<()> {
1338        let user_id = self
1339            .store()
1340            .await
1341            .user_id_for_connection(request.sender_id)?;
1342        let query = request.payload.query;
1343        let db = &self.app_state.db;
1344        let users = match query.len() {
1345            0 => vec![],
1346            1 | 2 => db
1347                .get_user_by_github_login(&query)
1348                .await?
1349                .into_iter()
1350                .collect(),
1351            _ => db.fuzzy_search_users(&query, 10).await?,
1352        };
1353        let users = users
1354            .into_iter()
1355            .filter(|user| user.id != user_id)
1356            .map(|user| proto::User {
1357                id: user.id.to_proto(),
1358                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1359                github_login: user.github_login,
1360            })
1361            .collect();
1362        response.send(proto::UsersResponse { users })?;
1363        Ok(())
1364    }
1365
1366    async fn request_contact(
1367        self: Arc<Server>,
1368        request: TypedEnvelope<proto::RequestContact>,
1369        response: Response<proto::RequestContact>,
1370    ) -> Result<()> {
1371        let requester_id = self
1372            .store()
1373            .await
1374            .user_id_for_connection(request.sender_id)?;
1375        let responder_id = UserId::from_proto(request.payload.responder_id);
1376        if requester_id == responder_id {
1377            return Err(anyhow!("cannot add yourself as a contact"))?;
1378        }
1379
1380        self.app_state
1381            .db
1382            .send_contact_request(requester_id, responder_id)
1383            .await?;
1384
1385        // Update outgoing contact requests of requester
1386        let mut update = proto::UpdateContacts::default();
1387        update.outgoing_requests.push(responder_id.to_proto());
1388        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1389            self.peer.send(connection_id, update.clone())?;
1390        }
1391
1392        // Update incoming contact requests of responder
1393        let mut update = proto::UpdateContacts::default();
1394        update
1395            .incoming_requests
1396            .push(proto::IncomingContactRequest {
1397                requester_id: requester_id.to_proto(),
1398                should_notify: true,
1399            });
1400        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1401            self.peer.send(connection_id, update.clone())?;
1402        }
1403
1404        response.send(proto::Ack {})?;
1405        Ok(())
1406    }
1407
1408    async fn respond_to_contact_request(
1409        self: Arc<Server>,
1410        request: TypedEnvelope<proto::RespondToContactRequest>,
1411        response: Response<proto::RespondToContactRequest>,
1412    ) -> Result<()> {
1413        let responder_id = self
1414            .store()
1415            .await
1416            .user_id_for_connection(request.sender_id)?;
1417        let requester_id = UserId::from_proto(request.payload.requester_id);
1418        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1419            self.app_state
1420                .db
1421                .dismiss_contact_notification(responder_id, requester_id)
1422                .await?;
1423        } else {
1424            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1425            self.app_state
1426                .db
1427                .respond_to_contact_request(responder_id, requester_id, accept)
1428                .await?;
1429
1430            let store = self.store().await;
1431            // Update responder with new contact
1432            let mut update = proto::UpdateContacts::default();
1433            if accept {
1434                update
1435                    .contacts
1436                    .push(store.contact_for_user(requester_id, false));
1437            }
1438            update
1439                .remove_incoming_requests
1440                .push(requester_id.to_proto());
1441            for connection_id in store.connection_ids_for_user(responder_id) {
1442                self.peer.send(connection_id, update.clone())?;
1443            }
1444
1445            // Update requester with new contact
1446            let mut update = proto::UpdateContacts::default();
1447            if accept {
1448                update
1449                    .contacts
1450                    .push(store.contact_for_user(responder_id, true));
1451            }
1452            update
1453                .remove_outgoing_requests
1454                .push(responder_id.to_proto());
1455            for connection_id in store.connection_ids_for_user(requester_id) {
1456                self.peer.send(connection_id, update.clone())?;
1457            }
1458        }
1459
1460        response.send(proto::Ack {})?;
1461        Ok(())
1462    }
1463
1464    async fn remove_contact(
1465        self: Arc<Server>,
1466        request: TypedEnvelope<proto::RemoveContact>,
1467        response: Response<proto::RemoveContact>,
1468    ) -> Result<()> {
1469        let requester_id = self
1470            .store()
1471            .await
1472            .user_id_for_connection(request.sender_id)?;
1473        let responder_id = UserId::from_proto(request.payload.user_id);
1474        self.app_state
1475            .db
1476            .remove_contact(requester_id, responder_id)
1477            .await?;
1478
1479        // Update outgoing contact requests of requester
1480        let mut update = proto::UpdateContacts::default();
1481        update
1482            .remove_outgoing_requests
1483            .push(responder_id.to_proto());
1484        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1485            self.peer.send(connection_id, update.clone())?;
1486        }
1487
1488        // Update incoming contact requests of responder
1489        let mut update = proto::UpdateContacts::default();
1490        update
1491            .remove_incoming_requests
1492            .push(requester_id.to_proto());
1493        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1494            self.peer.send(connection_id, update.clone())?;
1495        }
1496
1497        response.send(proto::Ack {})?;
1498        Ok(())
1499    }
1500
1501    async fn join_channel(
1502        self: Arc<Self>,
1503        request: TypedEnvelope<proto::JoinChannel>,
1504        response: Response<proto::JoinChannel>,
1505    ) -> Result<()> {
1506        let user_id = self
1507            .store()
1508            .await
1509            .user_id_for_connection(request.sender_id)?;
1510        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1511        if !self
1512            .app_state
1513            .db
1514            .can_user_access_channel(user_id, channel_id)
1515            .await?
1516        {
1517            Err(anyhow!("access denied"))?;
1518        }
1519
1520        self.store_mut()
1521            .await
1522            .join_channel(request.sender_id, channel_id);
1523        let messages = self
1524            .app_state
1525            .db
1526            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1527            .await?
1528            .into_iter()
1529            .map(|msg| proto::ChannelMessage {
1530                id: msg.id.to_proto(),
1531                body: msg.body,
1532                timestamp: msg.sent_at.unix_timestamp() as u64,
1533                sender_id: msg.sender_id.to_proto(),
1534                nonce: Some(msg.nonce.as_u128().into()),
1535            })
1536            .collect::<Vec<_>>();
1537        response.send(proto::JoinChannelResponse {
1538            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1539            messages,
1540        })?;
1541        Ok(())
1542    }
1543
1544    async fn leave_channel(
1545        self: Arc<Self>,
1546        request: TypedEnvelope<proto::LeaveChannel>,
1547    ) -> Result<()> {
1548        let user_id = self
1549            .store()
1550            .await
1551            .user_id_for_connection(request.sender_id)?;
1552        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1553        if !self
1554            .app_state
1555            .db
1556            .can_user_access_channel(user_id, channel_id)
1557            .await?
1558        {
1559            Err(anyhow!("access denied"))?;
1560        }
1561
1562        self.store_mut()
1563            .await
1564            .leave_channel(request.sender_id, channel_id);
1565
1566        Ok(())
1567    }
1568
1569    async fn send_channel_message(
1570        self: Arc<Self>,
1571        request: TypedEnvelope<proto::SendChannelMessage>,
1572        response: Response<proto::SendChannelMessage>,
1573    ) -> Result<()> {
1574        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1575        let user_id;
1576        let connection_ids;
1577        {
1578            let state = self.store().await;
1579            user_id = state.user_id_for_connection(request.sender_id)?;
1580            connection_ids = state.channel_connection_ids(channel_id)?;
1581        }
1582
1583        // Validate the message body.
1584        let body = request.payload.body.trim().to_string();
1585        if body.len() > MAX_MESSAGE_LEN {
1586            return Err(anyhow!("message is too long"))?;
1587        }
1588        if body.is_empty() {
1589            return Err(anyhow!("message can't be blank"))?;
1590        }
1591
1592        let timestamp = OffsetDateTime::now_utc();
1593        let nonce = request
1594            .payload
1595            .nonce
1596            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1597
1598        let message_id = self
1599            .app_state
1600            .db
1601            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1602            .await?
1603            .to_proto();
1604        let message = proto::ChannelMessage {
1605            sender_id: user_id.to_proto(),
1606            id: message_id,
1607            body,
1608            timestamp: timestamp.unix_timestamp() as u64,
1609            nonce: Some(nonce),
1610        };
1611        broadcast(request.sender_id, connection_ids, |conn_id| {
1612            self.peer.send(
1613                conn_id,
1614                proto::ChannelMessageSent {
1615                    channel_id: channel_id.to_proto(),
1616                    message: Some(message.clone()),
1617                },
1618            )
1619        });
1620        response.send(proto::SendChannelMessageResponse {
1621            message: Some(message),
1622        })?;
1623        Ok(())
1624    }
1625
1626    async fn get_channel_messages(
1627        self: Arc<Self>,
1628        request: TypedEnvelope<proto::GetChannelMessages>,
1629        response: Response<proto::GetChannelMessages>,
1630    ) -> Result<()> {
1631        let user_id = self
1632            .store()
1633            .await
1634            .user_id_for_connection(request.sender_id)?;
1635        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1636        if !self
1637            .app_state
1638            .db
1639            .can_user_access_channel(user_id, channel_id)
1640            .await?
1641        {
1642            Err(anyhow!("access denied"))?;
1643        }
1644
1645        let messages = self
1646            .app_state
1647            .db
1648            .get_channel_messages(
1649                channel_id,
1650                MESSAGE_COUNT_PER_PAGE,
1651                Some(MessageId::from_proto(request.payload.before_message_id)),
1652            )
1653            .await?
1654            .into_iter()
1655            .map(|msg| proto::ChannelMessage {
1656                id: msg.id.to_proto(),
1657                body: msg.body,
1658                timestamp: msg.sent_at.unix_timestamp() as u64,
1659                sender_id: msg.sender_id.to_proto(),
1660                nonce: Some(msg.nonce.as_u128().into()),
1661            })
1662            .collect::<Vec<_>>();
1663        response.send(proto::GetChannelMessagesResponse {
1664            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1665            messages,
1666        })?;
1667        Ok(())
1668    }
1669
1670    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1671        #[cfg(test)]
1672        tokio::task::yield_now().await;
1673        let guard = self.store.read().await;
1674        #[cfg(test)]
1675        tokio::task::yield_now().await;
1676        StoreReadGuard {
1677            guard,
1678            _not_send: PhantomData,
1679        }
1680    }
1681
1682    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1683        #[cfg(test)]
1684        tokio::task::yield_now().await;
1685        let guard = self.store.write().await;
1686        #[cfg(test)]
1687        tokio::task::yield_now().await;
1688        StoreWriteGuard {
1689            guard,
1690            _not_send: PhantomData,
1691        }
1692    }
1693
1694    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1695        ServerSnapshot {
1696            store: self.store.read().await,
1697            peer: &self.peer,
1698        }
1699    }
1700}
1701
1702impl<'a> Deref for StoreReadGuard<'a> {
1703    type Target = Store;
1704
1705    fn deref(&self) -> &Self::Target {
1706        &*self.guard
1707    }
1708}
1709
1710impl<'a> Deref for StoreWriteGuard<'a> {
1711    type Target = Store;
1712
1713    fn deref(&self) -> &Self::Target {
1714        &*self.guard
1715    }
1716}
1717
1718impl<'a> DerefMut for StoreWriteGuard<'a> {
1719    fn deref_mut(&mut self) -> &mut Self::Target {
1720        &mut *self.guard
1721    }
1722}
1723
impl<'a> Drop for StoreWriteGuard<'a> {
    fn drop(&mut self) {
        // In test builds only, run `check_invariants` every time a write guard is released.
        // The call goes through `Deref` to the store, presumably `Store::check_invariants`.
        // The intent is that a handler leaving the store inconsistent is caught at the end
        // of its write. In non-test builds this `drop` does nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
1730
1731impl Executor for RealExecutor {
1732    type Sleep = Sleep;
1733
1734    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1735        tokio::task::spawn(future);
1736    }
1737
1738    fn sleep(&self, duration: Duration) -> Self::Sleep {
1739        tokio::time::sleep(duration)
1740    }
1741}
1742
1743fn broadcast<F>(
1744    sender_id: ConnectionId,
1745    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1746    mut f: F,
1747) where
1748    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1749{
1750    for receiver_id in receiver_ids {
1751        if receiver_id != sender_id {
1752            f(receiver_id).trace_err();
1753        }
1754    }
1755}
1756
lazy_static! {
    /// Name of the request header in which clients send their RPC protocol version;
    /// read through the `ProtocolVersion` typed header.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1760
/// Typed `x-zed-protocol-version` header carrying the client's RPC protocol version as a
/// `u32`. It is extracted in `handle_websocket_request` and compared against
/// `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
1762
1763impl Header for ProtocolVersion {
1764    fn name() -> &'static HeaderName {
1765        &ZED_PROTOCOL_VERSION
1766    }
1767
1768    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1769    where
1770        Self: Sized,
1771        I: Iterator<Item = &'i axum::http::HeaderValue>,
1772    {
1773        let version = values
1774            .next()
1775            .ok_or_else(|| axum::headers::Error::invalid())?
1776            .to_str()
1777            .map_err(|_| axum::headers::Error::invalid())?
1778            .parse()
1779            .map_err(|_| axum::headers::Error::invalid())?;
1780        Ok(Self(version))
1781    }
1782
1783    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1784        values.extend([self.0.to_string().parse().unwrap()]);
1785    }
1786}
1787
1788pub fn routes(server: Arc<Server>) -> Router<Body> {
1789    Router::new()
1790        .route("/rpc", get(handle_websocket_request))
1791        .layer(
1792            ServiceBuilder::new()
1793                .layer(Extension(server.app_state.clone()))
1794                .layer(middleware::from_fn(auth::validate_header)),
1795        )
1796        .route("/metrics", get(handle_metrics))
1797        .layer(Extension(server))
1798}
1799
/// Axum handler for `GET /rpc`.
///
/// Clients whose `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`
/// receive `426 Upgrade Required`. Otherwise the request is upgraded to a WebSocket and
/// driven by `Server::handle_connection` with `RealExecutor`.
///
/// NOTE(review): the `User` extension is presumably inserted by the
/// `auth::validate_header` middleware that `routes` layers over `/rpc`; confirm in auth.rs.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Refuse to talk to a client speaking a different wire protocol.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    // Owned string form of the peer address, handed to `handle_connection`.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Local import providing `log_err` for the connection future below.
        use util::ResultExt;
        // Adapt axum's WebSocket to the tungstenite message type used by
        // `rpc::Connection`. Incoming messages go through `to_tungstenite_message`,
        // outgoing ones through `to_axum_message`, and stream errors through `err_into`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Runs after the upgrade response, so a connection error can only be
            // reported via `log_err` (which presumably logs and discards it).
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1830
1831pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1832    // We call `store_mut` here for its side effects of updating metrics.
1833    let metrics = server.store().await.metrics();
1834    METRIC_CONNECTIONS.set(metrics.connections as _);
1835    METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1836    METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1837    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1838
1839    let encoder = prometheus::TextEncoder::new();
1840    let metric_families = prometheus::gather();
1841    match encoder.encode_to_string(&metric_families) {
1842        Ok(string) => (StatusCode::OK, string).into_response(),
1843        Err(error) => (
1844            StatusCode::INTERNAL_SERVER_ERROR,
1845            format!("failed to encode metrics {:?}", error),
1846        )
1847            .into_response(),
1848    }
1849}
1850
1851fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1852    match message {
1853        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1854        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1855        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1856        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1857        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1858            code: frame.code.into(),
1859            reason: frame.reason,
1860        })),
1861    }
1862}
1863
1864fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1865    match message {
1866        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1867        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1868        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1869        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1870        AxumMessage::Close(frame) => {
1871            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1872                code: frame.code.into(),
1873                reason: frame.reason,
1874            }))
1875        }
1876    }
1877}
1878
/// Extension for fallible values whose errors should be reported rather than propagated,
/// e.g. the per-receiver sends in `broadcast`.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Converts `self` into an `Option`, reporting the error when there is one.
    fn trace_err(self) -> Option<Self::Ok>;
}
1884
1885impl<T, E> ResultExt for Result<T, E>
1886where
1887    E: std::fmt::Debug,
1888{
1889    type Ok = T;
1890
1891    fn trace_err(self) -> Option<T> {
1892        match self {
1893            Ok(value) => Some(value),
1894            Err(error) => {
1895                tracing::error!("{:?}", error);
1896                None
1897            }
1898        }
1899    }
1900}