rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, ProjectId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use prometheus::{register_int_gauge, IntGauge};
  33use rpc::{
  34    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  35    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  36};
  37use serde::{Serialize, Serializer};
  38use std::{
  39    any::TypeId,
  40    future::Future,
  41    marker::PhantomData,
  42    net::SocketAddr,
  43    ops::{Deref, DerefMut},
  44    rc::Rc,
  45    sync::{
  46        atomic::{AtomicBool, Ordering::SeqCst},
  47        Arc,
  48    },
  49    time::Duration,
  50};
  51use time::OffsetDateTime;
  52use tokio::{
  53    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  54    time::Sleep,
  55};
  56use tower::ServiceBuilder;
  57use tracing::{info_span, instrument, Instrument};
  58
  59pub use store::{Store, Worktree};
  60
  61lazy_static! {
  62    static ref METRIC_CONNECTIONS: IntGauge =
  63        register_int_gauge!("connections", "number of connections").unwrap();
  64    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
  65        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
  66    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
  67        register_int_gauge!("active_projects", "number of active projects").unwrap();
  68    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  69        "shared_projects",
  70        "number of open projects with one or more guests"
  71    )
  72    .unwrap();
  73}
  74
  75type MessageHandler =
  76    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  77
  78struct Response<R> {
  79    server: Arc<Server>,
  80    receipt: Receipt<R>,
  81    responded: Arc<AtomicBool>,
  82}
  83
  84impl<R: RequestMessage> Response<R> {
  85    fn send(self, payload: R::Response) -> Result<()> {
  86        self.responded.store(true, SeqCst);
  87        self.server.peer.respond(self.receipt, payload)?;
  88        Ok(())
  89    }
  90
  91    fn into_receipt(self) -> Receipt<R> {
  92        self.responded.store(true, SeqCst);
  93        self.receipt
  94    }
  95}
  96
  97pub struct Server {
  98    peer: Arc<Peer>,
  99    pub(crate) store: RwLock<Store>,
 100    app_state: Arc<AppState>,
 101    handlers: HashMap<TypeId, MessageHandler>,
 102    notifications: Option<mpsc::UnboundedSender<()>>,
 103}
 104
 105pub trait Executor: Send + Clone {
 106    type Sleep: Send + Future;
 107    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 108    fn sleep(&self, duration: Duration) -> Self::Sleep;
 109}
 110
 111#[derive(Clone)]
 112pub struct RealExecutor;
 113
 114const MESSAGE_COUNT_PER_PAGE: usize = 100;
 115const MAX_MESSAGE_LEN: usize = 1024;
 116
 117struct StoreReadGuard<'a> {
 118    guard: RwLockReadGuard<'a, Store>,
 119    _not_send: PhantomData<Rc<()>>,
 120}
 121
 122struct StoreWriteGuard<'a> {
 123    guard: RwLockWriteGuard<'a, Store>,
 124    _not_send: PhantomData<Rc<()>>,
 125}
 126
 127#[derive(Serialize)]
 128pub struct ServerSnapshot<'a> {
 129    peer: &'a Peer,
 130    #[serde(serialize_with = "serialize_deref")]
 131    store: RwLockReadGuard<'a, Store>,
 132}
 133
 134pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 135where
 136    S: Serializer,
 137    T: Deref<Target = U>,
 138    U: Serialize,
 139{
 140    Serialize::serialize(value.deref(), serializer)
 141}
 142
 143impl Server {
 144    pub fn new(
 145        app_state: Arc<AppState>,
 146        notifications: Option<mpsc::UnboundedSender<()>>,
 147    ) -> Arc<Self> {
 148        let mut server = Self {
 149            peer: Peer::new(),
 150            app_state,
 151            store: Default::default(),
 152            handlers: Default::default(),
 153            notifications,
 154        };
 155
 156        server
 157            .add_request_handler(Server::ping)
 158            .add_request_handler(Server::register_project)
 159            .add_request_handler(Server::unregister_project)
 160            .add_request_handler(Server::join_project)
 161            .add_message_handler(Server::leave_project)
 162            .add_message_handler(Server::respond_to_join_project_request)
 163            .add_message_handler(Server::update_project)
 164            .add_message_handler(Server::register_project_activity)
 165            .add_request_handler(Server::update_worktree)
 166            .add_message_handler(Server::start_language_server)
 167            .add_message_handler(Server::update_language_server)
 168            .add_message_handler(Server::update_diagnostic_summary)
 169            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 170            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 172            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 173            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 174            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 175            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 176            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 177            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 178            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 179            .add_request_handler(
 180                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 181            )
 182            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 183            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 184            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 185            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 186            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 187            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 188            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 189            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 190            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 191            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 192            .add_request_handler(Server::update_buffer)
 193            .add_message_handler(Server::update_buffer_file)
 194            .add_message_handler(Server::buffer_reloaded)
 195            .add_message_handler(Server::buffer_saved)
 196            .add_request_handler(Server::save_buffer)
 197            .add_request_handler(Server::get_channels)
 198            .add_request_handler(Server::get_users)
 199            .add_request_handler(Server::fuzzy_search_users)
 200            .add_request_handler(Server::request_contact)
 201            .add_request_handler(Server::remove_contact)
 202            .add_request_handler(Server::respond_to_contact_request)
 203            .add_request_handler(Server::join_channel)
 204            .add_message_handler(Server::leave_channel)
 205            .add_request_handler(Server::send_channel_message)
 206            .add_request_handler(Server::follow)
 207            .add_message_handler(Server::unfollow)
 208            .add_message_handler(Server::update_followers)
 209            .add_request_handler(Server::get_channel_messages);
 210
 211        Arc::new(server)
 212    }
 213
 214    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 215    where
 216        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 217        Fut: 'static + Send + Future<Output = Result<()>>,
 218        M: EnvelopedMessage,
 219    {
 220        let prev_handler = self.handlers.insert(
 221            TypeId::of::<M>(),
 222            Box::new(move |server, envelope| {
 223                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 224                let span = info_span!(
 225                    "handle message",
 226                    payload_type = envelope.payload_type_name()
 227                );
 228                span.in_scope(|| {
 229                    tracing::info!(
 230                        payload = format!("{:?}", envelope.payload).as_str(),
 231                        "message payload"
 232                    );
 233                });
 234                let future = (handler)(server, *envelope);
 235                async move {
 236                    if let Err(error) = future.await {
 237                        tracing::error!(%error, "error handling message");
 238                    }
 239                }
 240                .instrument(span)
 241                .boxed()
 242            }),
 243        );
 244        if prev_handler.is_some() {
 245            panic!("registered a handler for the same message twice");
 246        }
 247        self
 248    }
 249
 250    /// Handle a request while holding a lock to the store. This is useful when we're registering
 251    /// a connection but we want to respond on the connection before anybody else can send on it.
 252    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 253    where
 254        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 255        Fut: Send + Future<Output = Result<()>>,
 256        M: RequestMessage,
 257    {
 258        let handler = Arc::new(handler);
 259        self.add_message_handler(move |server, envelope| {
 260            let receipt = envelope.receipt();
 261            let handler = handler.clone();
 262            async move {
 263                let responded = Arc::new(AtomicBool::default());
 264                let response = Response {
 265                    server: server.clone(),
 266                    responded: responded.clone(),
 267                    receipt: envelope.receipt(),
 268                };
 269                match (handler)(server.clone(), envelope, response).await {
 270                    Ok(()) => {
 271                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 272                            Ok(())
 273                        } else {
 274                            Err(anyhow!("handler did not send a response"))?
 275                        }
 276                    }
 277                    Err(error) => {
 278                        server.peer.respond_with_error(
 279                            receipt,
 280                            proto::Error {
 281                                message: error.to_string(),
 282                            },
 283                        )?;
 284                        Err(error)
 285                    }
 286                }
 287            }
 288        })
 289    }
 290
 291    /// Start a long lived task that records which users are active in which projects.
 292    pub fn start_recording_project_activity<E: 'static + Executor>(
 293        self: &Arc<Self>,
 294        interval: Duration,
 295        executor: E,
 296    ) {
 297        executor.spawn_detached({
 298            let this = Arc::downgrade(self);
 299            let executor = executor.clone();
 300            async move {
 301                let mut period_start = OffsetDateTime::now_utc();
 302                let mut active_projects = Vec::<(UserId, ProjectId)>::new();
 303                loop {
 304                    let sleep = executor.sleep(interval);
 305                    sleep.await;
 306                    let this = if let Some(this) = this.upgrade() {
 307                        this
 308                    } else {
 309                        break;
 310                    };
 311
 312                    active_projects.clear();
 313                    active_projects.extend(this.store().await.projects().flat_map(
 314                        |(project_id, project)| {
 315                            project.guests.values().chain([&project.host]).filter_map(
 316                                |collaborator| {
 317                                    if !collaborator.admin
 318                                        && collaborator
 319                                            .last_activity
 320                                            .map_or(false, |activity| activity > period_start)
 321                                    {
 322                                        Some((collaborator.user_id, *project_id))
 323                                    } else {
 324                                        None
 325                                    }
 326                                },
 327                            )
 328                        },
 329                    ));
 330
 331                    let period_end = OffsetDateTime::now_utc();
 332                    this.app_state
 333                        .db
 334                        .record_project_activity(period_start..period_end, &active_projects)
 335                        .await
 336                        .trace_err();
 337                    period_start = period_end;
 338                }
 339            }
 340        });
 341    }
 342
 343    pub fn handle_connection<E: Executor>(
 344        self: &Arc<Self>,
 345        connection: Connection,
 346        address: String,
 347        user: User,
 348        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 349        executor: E,
 350    ) -> impl Future<Output = Result<()>> {
 351        let mut this = self.clone();
 352        let user_id = user.id;
 353        let login = user.github_login;
 354        let span = info_span!("handle connection", %user_id, %login, %address);
 355        async move {
 356            let (connection_id, handle_io, mut incoming_rx) = this
 357                .peer
 358                .add_connection(connection, {
 359                    let executor = executor.clone();
 360                    move |duration| {
 361                        let timer = executor.sleep(duration);
 362                        async move {
 363                            timer.await;
 364                        }
 365                    }
 366                })
 367                .await;
 368
 369            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 370
 371            if let Some(send_connection_id) = send_connection_id.as_mut() {
 372                let _ = send_connection_id.send(connection_id).await;
 373            }
 374
 375            if !user.connected_once {
 376                this.peer.send(connection_id, proto::ShowContacts {})?;
 377                this.app_state.db.set_user_connected_once(user_id, true).await?;
 378            }
 379
 380            let (contacts, invite_code) = future::try_join(
 381                this.app_state.db.get_contacts(user_id),
 382                this.app_state.db.get_invite_code_for_user(user_id)
 383            ).await?;
 384
 385            {
 386                let mut store = this.store_mut().await;
 387                store.add_connection(connection_id, user_id, user.admin);
 388                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 389
 390                if let Some((code, count)) = invite_code {
 391                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 392                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 393                        count,
 394                    })?;
 395                }
 396            }
 397            this.update_user_contacts(user_id).await?;
 398
 399            let handle_io = handle_io.fuse();
 400            futures::pin_mut!(handle_io);
 401            loop {
 402                let next_message = incoming_rx.next().fuse();
 403                futures::pin_mut!(next_message);
 404                futures::select_biased! {
 405                    result = handle_io => {
 406                        if let Err(error) = result {
 407                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 408                        }
 409                        break;
 410                    }
 411                    message = next_message => {
 412                        if let Some(message) = message {
 413                            let type_name = message.payload_type_name();
 414                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 415                            async {
 416                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 417                                    let notifications = this.notifications.clone();
 418                                    let is_background = message.is_background();
 419                                    let handle_message = (handler)(this.clone(), message);
 420                                    let handle_message = async move {
 421                                        handle_message.await;
 422                                        if let Some(mut notifications) = notifications {
 423                                            let _ = notifications.send(()).await;
 424                                        }
 425                                    };
 426                                    if is_background {
 427                                        executor.spawn_detached(handle_message);
 428                                    } else {
 429                                        handle_message.await;
 430                                    }
 431                                } else {
 432                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 433                                }
 434                            }.instrument(span).await;
 435                        } else {
 436                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 437                            break;
 438                        }
 439                    }
 440                }
 441            }
 442
 443            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 444            if let Err(error) = this.sign_out(connection_id).await {
 445                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 446            }
 447
 448            Ok(())
 449        }.instrument(span)
 450    }
 451
 452    #[instrument(skip(self), err)]
 453    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 454        self.peer.disconnect(connection_id);
 455
 456        let mut projects_to_unregister = Vec::new();
 457        let removed_user_id;
 458        {
 459            let mut store = self.store_mut().await;
 460            let removed_connection = store.remove_connection(connection_id)?;
 461
 462            for (project_id, project) in removed_connection.hosted_projects {
 463                projects_to_unregister.push(project_id);
 464                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 465                    self.peer.send(
 466                        conn_id,
 467                        proto::UnregisterProject {
 468                            project_id: project_id.to_proto(),
 469                        },
 470                    )
 471                });
 472
 473                for (_, receipts) in project.join_requests {
 474                    for receipt in receipts {
 475                        self.peer.respond(
 476                            receipt,
 477                            proto::JoinProjectResponse {
 478                                variant: Some(proto::join_project_response::Variant::Decline(
 479                                    proto::join_project_response::Decline {
 480                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 481                                    },
 482                                )),
 483                            },
 484                        )?;
 485                    }
 486                }
 487            }
 488
 489            for project_id in removed_connection.guest_project_ids {
 490                if let Some(project) = store.project(project_id).trace_err() {
 491                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 492                        self.peer.send(
 493                            conn_id,
 494                            proto::RemoveProjectCollaborator {
 495                                project_id: project_id.to_proto(),
 496                                peer_id: connection_id.0,
 497                            },
 498                        )
 499                    });
 500                    if project.guests.is_empty() {
 501                        self.peer
 502                            .send(
 503                                project.host_connection_id,
 504                                proto::ProjectUnshared {
 505                                    project_id: project_id.to_proto(),
 506                                },
 507                            )
 508                            .trace_err();
 509                    }
 510                }
 511            }
 512
 513            removed_user_id = removed_connection.user_id;
 514        };
 515
 516        self.update_user_contacts(removed_user_id).await.trace_err();
 517
 518        for project_id in projects_to_unregister {
 519            self.app_state
 520                .db
 521                .unregister_project(project_id)
 522                .await
 523                .trace_err();
 524        }
 525
 526        Ok(())
 527    }
 528
 529    pub async fn invite_code_redeemed(
 530        self: &Arc<Self>,
 531        code: &str,
 532        invitee_id: UserId,
 533    ) -> Result<()> {
 534        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 535        let store = self.store().await;
 536        let invitee_contact = store.contact_for_user(invitee_id, true);
 537        for connection_id in store.connection_ids_for_user(user.id) {
 538            self.peer.send(
 539                connection_id,
 540                proto::UpdateContacts {
 541                    contacts: vec![invitee_contact.clone()],
 542                    ..Default::default()
 543                },
 544            )?;
 545            self.peer.send(
 546                connection_id,
 547                proto::UpdateInviteInfo {
 548                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 549                    count: user.invite_count as u32,
 550                },
 551            )?;
 552        }
 553        Ok(())
 554    }
 555
 556    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 557        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 558            if let Some(invite_code) = &user.invite_code {
 559                let store = self.store().await;
 560                for connection_id in store.connection_ids_for_user(user_id) {
 561                    self.peer.send(
 562                        connection_id,
 563                        proto::UpdateInviteInfo {
 564                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 565                            count: user.invite_count as u32,
 566                        },
 567                    )?;
 568                }
 569            }
 570        }
 571        Ok(())
 572    }
 573
 574    async fn ping(
 575        self: Arc<Server>,
 576        _: TypedEnvelope<proto::Ping>,
 577        response: Response<proto::Ping>,
 578    ) -> Result<()> {
 579        response.send(proto::Ack {})?;
 580        Ok(())
 581    }
 582
 583    async fn register_project(
 584        self: Arc<Server>,
 585        request: TypedEnvelope<proto::RegisterProject>,
 586        response: Response<proto::RegisterProject>,
 587    ) -> Result<()> {
 588        let user_id = self
 589            .store()
 590            .await
 591            .user_id_for_connection(request.sender_id)?;
 592        let project_id = self.app_state.db.register_project(user_id).await?;
 593        self.store_mut()
 594            .await
 595            .register_project(request.sender_id, project_id)?;
 596
 597        response.send(proto::RegisterProjectResponse {
 598            project_id: project_id.to_proto(),
 599        })?;
 600
 601        Ok(())
 602    }
 603
 604    async fn unregister_project(
 605        self: Arc<Server>,
 606        request: TypedEnvelope<proto::UnregisterProject>,
 607        response: Response<proto::UnregisterProject>,
 608    ) -> Result<()> {
 609        let project_id = ProjectId::from_proto(request.payload.project_id);
 610        let (user_id, project) = {
 611            let mut state = self.store_mut().await;
 612            let project = state.unregister_project(project_id, request.sender_id)?;
 613            (state.user_id_for_connection(request.sender_id)?, project)
 614        };
 615        self.app_state.db.unregister_project(project_id).await?;
 616
 617        broadcast(
 618            request.sender_id,
 619            project.guests.keys().copied(),
 620            |conn_id| {
 621                self.peer.send(
 622                    conn_id,
 623                    proto::UnregisterProject {
 624                        project_id: project_id.to_proto(),
 625                    },
 626                )
 627            },
 628        );
 629        for (_, receipts) in project.join_requests {
 630            for receipt in receipts {
 631                self.peer.respond(
 632                    receipt,
 633                    proto::JoinProjectResponse {
 634                        variant: Some(proto::join_project_response::Variant::Decline(
 635                            proto::join_project_response::Decline {
 636                                reason: proto::join_project_response::decline::Reason::Closed
 637                                    as i32,
 638                            },
 639                        )),
 640                    },
 641                )?;
 642            }
 643        }
 644
 645        // Send out the `UpdateContacts` message before responding to the unregister
 646        // request. This way, when the project's host can keep track of the project's
 647        // remote id until after they've received the `UpdateContacts` message for
 648        // themself.
 649        self.update_user_contacts(user_id).await?;
 650        response.send(proto::Ack {})?;
 651
 652        Ok(())
 653    }
 654
 655    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 656        let contacts = self.app_state.db.get_contacts(user_id).await?;
 657        let store = self.store().await;
 658        let updated_contact = store.contact_for_user(user_id, false);
 659        for contact in contacts {
 660            if let db::Contact::Accepted {
 661                user_id: contact_user_id,
 662                ..
 663            } = contact
 664            {
 665                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 666                    self.peer
 667                        .send(
 668                            contact_conn_id,
 669                            proto::UpdateContacts {
 670                                contacts: vec![updated_contact.clone()],
 671                                remove_contacts: Default::default(),
 672                                incoming_requests: Default::default(),
 673                                remove_incoming_requests: Default::default(),
 674                                outgoing_requests: Default::default(),
 675                                remove_outgoing_requests: Default::default(),
 676                            },
 677                        )
 678                        .trace_err();
 679                }
 680            }
 681        }
 682        Ok(())
 683    }
 684
 685    async fn join_project(
 686        self: Arc<Server>,
 687        request: TypedEnvelope<proto::JoinProject>,
 688        response: Response<proto::JoinProject>,
 689    ) -> Result<()> {
 690        let project_id = ProjectId::from_proto(request.payload.project_id);
 691
 692        let host_user_id;
 693        let guest_user_id;
 694        let host_connection_id;
 695        {
 696            let state = self.store().await;
 697            let project = state.project(project_id)?;
 698            host_user_id = project.host.user_id;
 699            host_connection_id = project.host_connection_id;
 700            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 701        };
 702
 703        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 704        let has_contact = self
 705            .app_state
 706            .db
 707            .has_contact(guest_user_id, host_user_id)
 708            .await?;
 709        if !has_contact {
 710            return Err(anyhow!("no such project"))?;
 711        }
 712
 713        self.store_mut().await.request_join_project(
 714            guest_user_id,
 715            project_id,
 716            response.into_receipt(),
 717        )?;
 718        self.peer.send(
 719            host_connection_id,
 720            proto::RequestJoinProject {
 721                project_id: project_id.to_proto(),
 722                requester_id: guest_user_id.to_proto(),
 723            },
 724        )?;
 725        Ok(())
 726    }
 727
 728    async fn respond_to_join_project_request(
 729        self: Arc<Server>,
 730        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 731    ) -> Result<()> {
 732        let host_user_id;
 733
 734        {
 735            let mut state = self.store_mut().await;
 736            let project_id = ProjectId::from_proto(request.payload.project_id);
 737            let project = state.project(project_id)?;
 738            if project.host_connection_id != request.sender_id {
 739                Err(anyhow!("no such connection"))?;
 740            }
 741
 742            host_user_id = project.host.user_id;
 743            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 744
 745            if !request.payload.allow {
 746                let receipts = state
 747                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 748                    .ok_or_else(|| anyhow!("no such request"))?;
 749                for receipt in receipts {
 750                    self.peer.respond(
 751                        receipt,
 752                        proto::JoinProjectResponse {
 753                            variant: Some(proto::join_project_response::Variant::Decline(
 754                                proto::join_project_response::Decline {
 755                                    reason: proto::join_project_response::decline::Reason::Declined
 756                                        as i32,
 757                                },
 758                            )),
 759                        },
 760                    )?;
 761                }
 762                return Ok(());
 763            }
 764
 765            let (receipts_with_replica_ids, project) = state
 766                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 767                .ok_or_else(|| anyhow!("no such request"))?;
 768
 769            let peer_count = project.guests.len();
 770            let mut collaborators = Vec::with_capacity(peer_count);
 771            collaborators.push(proto::Collaborator {
 772                peer_id: project.host_connection_id.0,
 773                replica_id: 0,
 774                user_id: project.host.user_id.to_proto(),
 775            });
 776            let worktrees = project
 777                .worktrees
 778                .iter()
 779                .filter_map(|(id, shared_worktree)| {
 780                    let worktree = project.worktrees.get(&id)?;
 781                    Some(proto::Worktree {
 782                        id: *id,
 783                        root_name: worktree.root_name.clone(),
 784                        entries: shared_worktree.entries.values().cloned().collect(),
 785                        diagnostic_summaries: shared_worktree
 786                            .diagnostic_summaries
 787                            .values()
 788                            .cloned()
 789                            .collect(),
 790                        visible: worktree.visible,
 791                        scan_id: shared_worktree.scan_id,
 792                    })
 793                })
 794                .collect::<Vec<_>>();
 795
 796            // Add all guests other than the requesting user's own connections as collaborators
 797            for (guest_conn_id, guest) in &project.guests {
 798                if receipts_with_replica_ids
 799                    .iter()
 800                    .all(|(receipt, _)| receipt.sender_id != *guest_conn_id)
 801                {
 802                    collaborators.push(proto::Collaborator {
 803                        peer_id: guest_conn_id.0,
 804                        replica_id: guest.replica_id as u32,
 805                        user_id: guest.user_id.to_proto(),
 806                    });
 807                }
 808            }
 809
 810            for conn_id in project.connection_ids() {
 811                for (receipt, replica_id) in &receipts_with_replica_ids {
 812                    if conn_id != receipt.sender_id {
 813                        self.peer.send(
 814                            conn_id,
 815                            proto::AddProjectCollaborator {
 816                                project_id: project_id.to_proto(),
 817                                collaborator: Some(proto::Collaborator {
 818                                    peer_id: receipt.sender_id.0,
 819                                    replica_id: *replica_id as u32,
 820                                    user_id: guest_user_id.to_proto(),
 821                                }),
 822                            },
 823                        )?;
 824                    }
 825                }
 826            }
 827
 828            for (receipt, replica_id) in receipts_with_replica_ids {
 829                self.peer.respond(
 830                    receipt,
 831                    proto::JoinProjectResponse {
 832                        variant: Some(proto::join_project_response::Variant::Accept(
 833                            proto::join_project_response::Accept {
 834                                worktrees: worktrees.clone(),
 835                                replica_id: replica_id as u32,
 836                                collaborators: collaborators.clone(),
 837                                language_servers: project.language_servers.clone(),
 838                            },
 839                        )),
 840                    },
 841                )?;
 842            }
 843        }
 844
 845        self.update_user_contacts(host_user_id).await?;
 846        Ok(())
 847    }
 848
 849    async fn leave_project(
 850        self: Arc<Server>,
 851        request: TypedEnvelope<proto::LeaveProject>,
 852    ) -> Result<()> {
 853        let sender_id = request.sender_id;
 854        let project_id = ProjectId::from_proto(request.payload.project_id);
 855        let project;
 856        {
 857            let mut store = self.store_mut().await;
 858            project = store.leave_project(sender_id, project_id)?;
 859            tracing::info!(
 860                %project_id,
 861                host_user_id = %project.host_user_id,
 862                host_connection_id = %project.host_connection_id,
 863                "leave project"
 864            );
 865
 866            if project.remove_collaborator {
 867                broadcast(sender_id, project.connection_ids, |conn_id| {
 868                    self.peer.send(
 869                        conn_id,
 870                        proto::RemoveProjectCollaborator {
 871                            project_id: project_id.to_proto(),
 872                            peer_id: sender_id.0,
 873                        },
 874                    )
 875                });
 876            }
 877
 878            if let Some(requester_id) = project.cancel_request {
 879                self.peer.send(
 880                    project.host_connection_id,
 881                    proto::JoinProjectRequestCancelled {
 882                        project_id: project_id.to_proto(),
 883                        requester_id: requester_id.to_proto(),
 884                    },
 885                )?;
 886            }
 887
 888            if project.unshare {
 889                self.peer.send(
 890                    project.host_connection_id,
 891                    proto::ProjectUnshared {
 892                        project_id: project_id.to_proto(),
 893                    },
 894                )?;
 895            }
 896        }
 897        self.update_user_contacts(project.host_user_id).await?;
 898        Ok(())
 899    }
 900
 901    async fn update_project(
 902        self: Arc<Server>,
 903        request: TypedEnvelope<proto::UpdateProject>,
 904    ) -> Result<()> {
 905        let project_id = ProjectId::from_proto(request.payload.project_id);
 906        let user_id;
 907        {
 908            let mut state = self.store_mut().await;
 909            user_id = state.user_id_for_connection(request.sender_id)?;
 910            let guest_connection_ids = state
 911                .read_project(project_id, request.sender_id)?
 912                .guest_connection_ids();
 913            state.update_project(project_id, &request.payload.worktrees, request.sender_id)?;
 914            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 915                self.peer
 916                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 917            });
 918        };
 919        self.update_user_contacts(user_id).await?;
 920        Ok(())
 921    }
 922
 923    async fn register_project_activity(
 924        self: Arc<Server>,
 925        request: TypedEnvelope<proto::RegisterProjectActivity>,
 926    ) -> Result<()> {
 927        self.store_mut().await.register_project_activity(
 928            ProjectId::from_proto(request.payload.project_id),
 929            request.sender_id,
 930        )?;
 931        Ok(())
 932    }
 933
 934    async fn update_worktree(
 935        self: Arc<Server>,
 936        request: TypedEnvelope<proto::UpdateWorktree>,
 937        response: Response<proto::UpdateWorktree>,
 938    ) -> Result<()> {
 939        let project_id = ProjectId::from_proto(request.payload.project_id);
 940        let worktree_id = request.payload.worktree_id;
 941        let (connection_ids, metadata_changed, extension_counts) = {
 942            let mut store = self.store_mut().await;
 943            let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
 944                request.sender_id,
 945                project_id,
 946                worktree_id,
 947                &request.payload.root_name,
 948                &request.payload.removed_entries,
 949                &request.payload.updated_entries,
 950                request.payload.scan_id,
 951            )?;
 952            (connection_ids, metadata_changed, extension_counts.clone())
 953        };
 954        self.app_state
 955            .db
 956            .update_worktree_extensions(project_id, worktree_id, extension_counts)
 957            .await?;
 958
 959        broadcast(request.sender_id, connection_ids, |connection_id| {
 960            self.peer
 961                .forward_send(request.sender_id, connection_id, request.payload.clone())
 962        });
 963        if metadata_changed {
 964            let user_id = self
 965                .store()
 966                .await
 967                .user_id_for_connection(request.sender_id)?;
 968            self.update_user_contacts(user_id).await?;
 969        }
 970        response.send(proto::Ack {})?;
 971        Ok(())
 972    }
 973
 974    async fn update_diagnostic_summary(
 975        self: Arc<Server>,
 976        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 977    ) -> Result<()> {
 978        let summary = request
 979            .payload
 980            .summary
 981            .clone()
 982            .ok_or_else(|| anyhow!("invalid summary"))?;
 983        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 984            ProjectId::from_proto(request.payload.project_id),
 985            request.payload.worktree_id,
 986            request.sender_id,
 987            summary,
 988        )?;
 989
 990        broadcast(request.sender_id, receiver_ids, |connection_id| {
 991            self.peer
 992                .forward_send(request.sender_id, connection_id, request.payload.clone())
 993        });
 994        Ok(())
 995    }
 996
 997    async fn start_language_server(
 998        self: Arc<Server>,
 999        request: TypedEnvelope<proto::StartLanguageServer>,
1000    ) -> Result<()> {
1001        let receiver_ids = self.store_mut().await.start_language_server(
1002            ProjectId::from_proto(request.payload.project_id),
1003            request.sender_id,
1004            request
1005                .payload
1006                .server
1007                .clone()
1008                .ok_or_else(|| anyhow!("invalid language server"))?,
1009        )?;
1010        broadcast(request.sender_id, receiver_ids, |connection_id| {
1011            self.peer
1012                .forward_send(request.sender_id, connection_id, request.payload.clone())
1013        });
1014        Ok(())
1015    }
1016
1017    async fn update_language_server(
1018        self: Arc<Server>,
1019        request: TypedEnvelope<proto::UpdateLanguageServer>,
1020    ) -> Result<()> {
1021        let receiver_ids = self.store().await.project_connection_ids(
1022            ProjectId::from_proto(request.payload.project_id),
1023            request.sender_id,
1024        )?;
1025        broadcast(request.sender_id, receiver_ids, |connection_id| {
1026            self.peer
1027                .forward_send(request.sender_id, connection_id, request.payload.clone())
1028        });
1029        Ok(())
1030    }
1031
1032    async fn forward_project_request<T>(
1033        self: Arc<Server>,
1034        request: TypedEnvelope<T>,
1035        response: Response<T>,
1036    ) -> Result<()>
1037    where
1038        T: EntityMessage + RequestMessage,
1039    {
1040        let host_connection_id = self
1041            .store()
1042            .await
1043            .read_project(
1044                ProjectId::from_proto(request.payload.remote_entity_id()),
1045                request.sender_id,
1046            )?
1047            .host_connection_id;
1048
1049        response.send(
1050            self.peer
1051                .forward_request(request.sender_id, host_connection_id, request.payload)
1052                .await?,
1053        )?;
1054        Ok(())
1055    }
1056
1057    async fn save_buffer(
1058        self: Arc<Server>,
1059        request: TypedEnvelope<proto::SaveBuffer>,
1060        response: Response<proto::SaveBuffer>,
1061    ) -> Result<()> {
1062        let project_id = ProjectId::from_proto(request.payload.project_id);
1063        let host = self
1064            .store()
1065            .await
1066            .read_project(project_id, request.sender_id)?
1067            .host_connection_id;
1068        let response_payload = self
1069            .peer
1070            .forward_request(request.sender_id, host, request.payload.clone())
1071            .await?;
1072
1073        let mut guests = self
1074            .store()
1075            .await
1076            .read_project(project_id, request.sender_id)?
1077            .connection_ids();
1078        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1079        broadcast(host, guests, |conn_id| {
1080            self.peer
1081                .forward_send(host, conn_id, response_payload.clone())
1082        });
1083        response.send(response_payload)?;
1084        Ok(())
1085    }
1086
1087    async fn update_buffer(
1088        self: Arc<Server>,
1089        request: TypedEnvelope<proto::UpdateBuffer>,
1090        response: Response<proto::UpdateBuffer>,
1091    ) -> Result<()> {
1092        let project_id = ProjectId::from_proto(request.payload.project_id);
1093        let receiver_ids = {
1094            let mut store = self.store_mut().await;
1095            store.register_project_activity(project_id, request.sender_id)?;
1096            store.project_connection_ids(project_id, request.sender_id)?
1097        };
1098
1099        broadcast(request.sender_id, receiver_ids, |connection_id| {
1100            self.peer
1101                .forward_send(request.sender_id, connection_id, request.payload.clone())
1102        });
1103        response.send(proto::Ack {})?;
1104        Ok(())
1105    }
1106
1107    async fn update_buffer_file(
1108        self: Arc<Server>,
1109        request: TypedEnvelope<proto::UpdateBufferFile>,
1110    ) -> Result<()> {
1111        let receiver_ids = self.store().await.project_connection_ids(
1112            ProjectId::from_proto(request.payload.project_id),
1113            request.sender_id,
1114        )?;
1115        broadcast(request.sender_id, receiver_ids, |connection_id| {
1116            self.peer
1117                .forward_send(request.sender_id, connection_id, request.payload.clone())
1118        });
1119        Ok(())
1120    }
1121
1122    async fn buffer_reloaded(
1123        self: Arc<Server>,
1124        request: TypedEnvelope<proto::BufferReloaded>,
1125    ) -> Result<()> {
1126        let receiver_ids = self.store().await.project_connection_ids(
1127            ProjectId::from_proto(request.payload.project_id),
1128            request.sender_id,
1129        )?;
1130        broadcast(request.sender_id, receiver_ids, |connection_id| {
1131            self.peer
1132                .forward_send(request.sender_id, connection_id, request.payload.clone())
1133        });
1134        Ok(())
1135    }
1136
1137    async fn buffer_saved(
1138        self: Arc<Server>,
1139        request: TypedEnvelope<proto::BufferSaved>,
1140    ) -> Result<()> {
1141        let receiver_ids = self.store().await.project_connection_ids(
1142            ProjectId::from_proto(request.payload.project_id),
1143            request.sender_id,
1144        )?;
1145        broadcast(request.sender_id, receiver_ids, |connection_id| {
1146            self.peer
1147                .forward_send(request.sender_id, connection_id, request.payload.clone())
1148        });
1149        Ok(())
1150    }
1151
1152    async fn follow(
1153        self: Arc<Self>,
1154        request: TypedEnvelope<proto::Follow>,
1155        response: Response<proto::Follow>,
1156    ) -> Result<()> {
1157        let project_id = ProjectId::from_proto(request.payload.project_id);
1158        let leader_id = ConnectionId(request.payload.leader_id);
1159        let follower_id = request.sender_id;
1160        {
1161            let mut store = self.store_mut().await;
1162            if !store
1163                .project_connection_ids(project_id, follower_id)?
1164                .contains(&leader_id)
1165            {
1166                Err(anyhow!("no such peer"))?;
1167            }
1168
1169            store.register_project_activity(project_id, follower_id)?;
1170        }
1171
1172        let mut response_payload = self
1173            .peer
1174            .forward_request(request.sender_id, leader_id, request.payload)
1175            .await?;
1176        response_payload
1177            .views
1178            .retain(|view| view.leader_id != Some(follower_id.0));
1179        response.send(response_payload)?;
1180        Ok(())
1181    }
1182
1183    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1184        let project_id = ProjectId::from_proto(request.payload.project_id);
1185        let leader_id = ConnectionId(request.payload.leader_id);
1186        let mut store = self.store_mut().await;
1187        if !store
1188            .project_connection_ids(project_id, request.sender_id)?
1189            .contains(&leader_id)
1190        {
1191            Err(anyhow!("no such peer"))?;
1192        }
1193        store.register_project_activity(project_id, request.sender_id)?;
1194        self.peer
1195            .forward_send(request.sender_id, leader_id, request.payload)?;
1196        Ok(())
1197    }
1198
1199    async fn update_followers(
1200        self: Arc<Self>,
1201        request: TypedEnvelope<proto::UpdateFollowers>,
1202    ) -> Result<()> {
1203        let project_id = ProjectId::from_proto(request.payload.project_id);
1204        let mut store = self.store_mut().await;
1205        store.register_project_activity(project_id, request.sender_id)?;
1206        let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
1207        let leader_id = request
1208            .payload
1209            .variant
1210            .as_ref()
1211            .and_then(|variant| match variant {
1212                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1213                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1214                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1215            });
1216        for follower_id in &request.payload.follower_ids {
1217            let follower_id = ConnectionId(*follower_id);
1218            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1219                self.peer
1220                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1221            }
1222        }
1223        Ok(())
1224    }
1225
1226    async fn get_channels(
1227        self: Arc<Server>,
1228        request: TypedEnvelope<proto::GetChannels>,
1229        response: Response<proto::GetChannels>,
1230    ) -> Result<()> {
1231        let user_id = self
1232            .store()
1233            .await
1234            .user_id_for_connection(request.sender_id)?;
1235        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1236        response.send(proto::GetChannelsResponse {
1237            channels: channels
1238                .into_iter()
1239                .map(|chan| proto::Channel {
1240                    id: chan.id.to_proto(),
1241                    name: chan.name,
1242                })
1243                .collect(),
1244        })?;
1245        Ok(())
1246    }
1247
1248    async fn get_users(
1249        self: Arc<Server>,
1250        request: TypedEnvelope<proto::GetUsers>,
1251        response: Response<proto::GetUsers>,
1252    ) -> Result<()> {
1253        let user_ids = request
1254            .payload
1255            .user_ids
1256            .into_iter()
1257            .map(UserId::from_proto)
1258            .collect();
1259        let users = self
1260            .app_state
1261            .db
1262            .get_users_by_ids(user_ids)
1263            .await?
1264            .into_iter()
1265            .map(|user| proto::User {
1266                id: user.id.to_proto(),
1267                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1268                github_login: user.github_login,
1269            })
1270            .collect();
1271        response.send(proto::UsersResponse { users })?;
1272        Ok(())
1273    }
1274
1275    async fn fuzzy_search_users(
1276        self: Arc<Server>,
1277        request: TypedEnvelope<proto::FuzzySearchUsers>,
1278        response: Response<proto::FuzzySearchUsers>,
1279    ) -> Result<()> {
1280        let user_id = self
1281            .store()
1282            .await
1283            .user_id_for_connection(request.sender_id)?;
1284        let query = request.payload.query;
1285        let db = &self.app_state.db;
1286        let users = match query.len() {
1287            0 => vec![],
1288            1 | 2 => db
1289                .get_user_by_github_login(&query)
1290                .await?
1291                .into_iter()
1292                .collect(),
1293            _ => db.fuzzy_search_users(&query, 10).await?,
1294        };
1295        let users = users
1296            .into_iter()
1297            .filter(|user| user.id != user_id)
1298            .map(|user| proto::User {
1299                id: user.id.to_proto(),
1300                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1301                github_login: user.github_login,
1302            })
1303            .collect();
1304        response.send(proto::UsersResponse { users })?;
1305        Ok(())
1306    }
1307
1308    async fn request_contact(
1309        self: Arc<Server>,
1310        request: TypedEnvelope<proto::RequestContact>,
1311        response: Response<proto::RequestContact>,
1312    ) -> Result<()> {
1313        let requester_id = self
1314            .store()
1315            .await
1316            .user_id_for_connection(request.sender_id)?;
1317        let responder_id = UserId::from_proto(request.payload.responder_id);
1318        if requester_id == responder_id {
1319            return Err(anyhow!("cannot add yourself as a contact"))?;
1320        }
1321
1322        self.app_state
1323            .db
1324            .send_contact_request(requester_id, responder_id)
1325            .await?;
1326
1327        // Update outgoing contact requests of requester
1328        let mut update = proto::UpdateContacts::default();
1329        update.outgoing_requests.push(responder_id.to_proto());
1330        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1331            self.peer.send(connection_id, update.clone())?;
1332        }
1333
1334        // Update incoming contact requests of responder
1335        let mut update = proto::UpdateContacts::default();
1336        update
1337            .incoming_requests
1338            .push(proto::IncomingContactRequest {
1339                requester_id: requester_id.to_proto(),
1340                should_notify: true,
1341            });
1342        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1343            self.peer.send(connection_id, update.clone())?;
1344        }
1345
1346        response.send(proto::Ack {})?;
1347        Ok(())
1348    }
1349
1350    async fn respond_to_contact_request(
1351        self: Arc<Server>,
1352        request: TypedEnvelope<proto::RespondToContactRequest>,
1353        response: Response<proto::RespondToContactRequest>,
1354    ) -> Result<()> {
1355        let responder_id = self
1356            .store()
1357            .await
1358            .user_id_for_connection(request.sender_id)?;
1359        let requester_id = UserId::from_proto(request.payload.requester_id);
1360        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1361            self.app_state
1362                .db
1363                .dismiss_contact_notification(responder_id, requester_id)
1364                .await?;
1365        } else {
1366            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1367            self.app_state
1368                .db
1369                .respond_to_contact_request(responder_id, requester_id, accept)
1370                .await?;
1371
1372            let store = self.store().await;
1373            // Update responder with new contact
1374            let mut update = proto::UpdateContacts::default();
1375            if accept {
1376                update
1377                    .contacts
1378                    .push(store.contact_for_user(requester_id, false));
1379            }
1380            update
1381                .remove_incoming_requests
1382                .push(requester_id.to_proto());
1383            for connection_id in store.connection_ids_for_user(responder_id) {
1384                self.peer.send(connection_id, update.clone())?;
1385            }
1386
1387            // Update requester with new contact
1388            let mut update = proto::UpdateContacts::default();
1389            if accept {
1390                update
1391                    .contacts
1392                    .push(store.contact_for_user(responder_id, true));
1393            }
1394            update
1395                .remove_outgoing_requests
1396                .push(responder_id.to_proto());
1397            for connection_id in store.connection_ids_for_user(requester_id) {
1398                self.peer.send(connection_id, update.clone())?;
1399            }
1400        }
1401
1402        response.send(proto::Ack {})?;
1403        Ok(())
1404    }
1405
1406    async fn remove_contact(
1407        self: Arc<Server>,
1408        request: TypedEnvelope<proto::RemoveContact>,
1409        response: Response<proto::RemoveContact>,
1410    ) -> Result<()> {
1411        let requester_id = self
1412            .store()
1413            .await
1414            .user_id_for_connection(request.sender_id)?;
1415        let responder_id = UserId::from_proto(request.payload.user_id);
1416        self.app_state
1417            .db
1418            .remove_contact(requester_id, responder_id)
1419            .await?;
1420
1421        // Update outgoing contact requests of requester
1422        let mut update = proto::UpdateContacts::default();
1423        update
1424            .remove_outgoing_requests
1425            .push(responder_id.to_proto());
1426        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1427            self.peer.send(connection_id, update.clone())?;
1428        }
1429
1430        // Update incoming contact requests of responder
1431        let mut update = proto::UpdateContacts::default();
1432        update
1433            .remove_incoming_requests
1434            .push(requester_id.to_proto());
1435        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1436            self.peer.send(connection_id, update.clone())?;
1437        }
1438
1439        response.send(proto::Ack {})?;
1440        Ok(())
1441    }
1442
1443    async fn join_channel(
1444        self: Arc<Self>,
1445        request: TypedEnvelope<proto::JoinChannel>,
1446        response: Response<proto::JoinChannel>,
1447    ) -> Result<()> {
1448        let user_id = self
1449            .store()
1450            .await
1451            .user_id_for_connection(request.sender_id)?;
1452        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1453        if !self
1454            .app_state
1455            .db
1456            .can_user_access_channel(user_id, channel_id)
1457            .await?
1458        {
1459            Err(anyhow!("access denied"))?;
1460        }
1461
1462        self.store_mut()
1463            .await
1464            .join_channel(request.sender_id, channel_id);
1465        let messages = self
1466            .app_state
1467            .db
1468            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1469            .await?
1470            .into_iter()
1471            .map(|msg| proto::ChannelMessage {
1472                id: msg.id.to_proto(),
1473                body: msg.body,
1474                timestamp: msg.sent_at.unix_timestamp() as u64,
1475                sender_id: msg.sender_id.to_proto(),
1476                nonce: Some(msg.nonce.as_u128().into()),
1477            })
1478            .collect::<Vec<_>>();
1479        response.send(proto::JoinChannelResponse {
1480            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1481            messages,
1482        })?;
1483        Ok(())
1484    }
1485
1486    async fn leave_channel(
1487        self: Arc<Self>,
1488        request: TypedEnvelope<proto::LeaveChannel>,
1489    ) -> Result<()> {
1490        let user_id = self
1491            .store()
1492            .await
1493            .user_id_for_connection(request.sender_id)?;
1494        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1495        if !self
1496            .app_state
1497            .db
1498            .can_user_access_channel(user_id, channel_id)
1499            .await?
1500        {
1501            Err(anyhow!("access denied"))?;
1502        }
1503
1504        self.store_mut()
1505            .await
1506            .leave_channel(request.sender_id, channel_id);
1507
1508        Ok(())
1509    }
1510
1511    async fn send_channel_message(
1512        self: Arc<Self>,
1513        request: TypedEnvelope<proto::SendChannelMessage>,
1514        response: Response<proto::SendChannelMessage>,
1515    ) -> Result<()> {
1516        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1517        let user_id;
1518        let connection_ids;
1519        {
1520            let state = self.store().await;
1521            user_id = state.user_id_for_connection(request.sender_id)?;
1522            connection_ids = state.channel_connection_ids(channel_id)?;
1523        }
1524
1525        // Validate the message body.
1526        let body = request.payload.body.trim().to_string();
1527        if body.len() > MAX_MESSAGE_LEN {
1528            return Err(anyhow!("message is too long"))?;
1529        }
1530        if body.is_empty() {
1531            return Err(anyhow!("message can't be blank"))?;
1532        }
1533
1534        let timestamp = OffsetDateTime::now_utc();
1535        let nonce = request
1536            .payload
1537            .nonce
1538            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1539
1540        let message_id = self
1541            .app_state
1542            .db
1543            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1544            .await?
1545            .to_proto();
1546        let message = proto::ChannelMessage {
1547            sender_id: user_id.to_proto(),
1548            id: message_id,
1549            body,
1550            timestamp: timestamp.unix_timestamp() as u64,
1551            nonce: Some(nonce),
1552        };
1553        broadcast(request.sender_id, connection_ids, |conn_id| {
1554            self.peer.send(
1555                conn_id,
1556                proto::ChannelMessageSent {
1557                    channel_id: channel_id.to_proto(),
1558                    message: Some(message.clone()),
1559                },
1560            )
1561        });
1562        response.send(proto::SendChannelMessageResponse {
1563            message: Some(message),
1564        })?;
1565        Ok(())
1566    }
1567
1568    async fn get_channel_messages(
1569        self: Arc<Self>,
1570        request: TypedEnvelope<proto::GetChannelMessages>,
1571        response: Response<proto::GetChannelMessages>,
1572    ) -> Result<()> {
1573        let user_id = self
1574            .store()
1575            .await
1576            .user_id_for_connection(request.sender_id)?;
1577        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1578        if !self
1579            .app_state
1580            .db
1581            .can_user_access_channel(user_id, channel_id)
1582            .await?
1583        {
1584            Err(anyhow!("access denied"))?;
1585        }
1586
1587        let messages = self
1588            .app_state
1589            .db
1590            .get_channel_messages(
1591                channel_id,
1592                MESSAGE_COUNT_PER_PAGE,
1593                Some(MessageId::from_proto(request.payload.before_message_id)),
1594            )
1595            .await?
1596            .into_iter()
1597            .map(|msg| proto::ChannelMessage {
1598                id: msg.id.to_proto(),
1599                body: msg.body,
1600                timestamp: msg.sent_at.unix_timestamp() as u64,
1601                sender_id: msg.sender_id.to_proto(),
1602                nonce: Some(msg.nonce.as_u128().into()),
1603            })
1604            .collect::<Vec<_>>();
1605        response.send(proto::GetChannelMessagesResponse {
1606            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1607            messages,
1608        })?;
1609        Ok(())
1610    }
1611
1612    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1613        #[cfg(test)]
1614        tokio::task::yield_now().await;
1615        let guard = self.store.read().await;
1616        #[cfg(test)]
1617        tokio::task::yield_now().await;
1618        StoreReadGuard {
1619            guard,
1620            _not_send: PhantomData,
1621        }
1622    }
1623
1624    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1625        #[cfg(test)]
1626        tokio::task::yield_now().await;
1627        let guard = self.store.write().await;
1628        #[cfg(test)]
1629        tokio::task::yield_now().await;
1630        StoreWriteGuard {
1631            guard,
1632            _not_send: PhantomData,
1633        }
1634    }
1635
1636    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1637        ServerSnapshot {
1638            store: self.store.read().await,
1639            peer: &self.peer,
1640        }
1641    }
1642}
1643
1644impl<'a> Deref for StoreReadGuard<'a> {
1645    type Target = Store;
1646
1647    fn deref(&self) -> &Self::Target {
1648        &*self.guard
1649    }
1650}
1651
1652impl<'a> Deref for StoreWriteGuard<'a> {
1653    type Target = Store;
1654
1655    fn deref(&self) -> &Self::Target {
1656        &*self.guard
1657    }
1658}
1659
1660impl<'a> DerefMut for StoreWriteGuard<'a> {
1661    fn deref_mut(&mut self) -> &mut Self::Target {
1662        &mut *self.guard
1663    }
1664}
1665
1666impl<'a> Drop for StoreWriteGuard<'a> {
1667    fn drop(&mut self) {
1668        #[cfg(test)]
1669        self.check_invariants();
1670
1671        let metrics = self.metrics();
1672
1673        METRIC_CONNECTIONS.set(metrics.connections as _);
1674        METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1675        METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1676        METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1677
1678        tracing::info!(
1679            connections = metrics.connections,
1680            registered_projects = metrics.registered_projects,
1681            active_projects = metrics.active_projects,
1682            shared_projects = metrics.shared_projects,
1683            "metrics"
1684        );
1685    }
1686}
1687
1688impl Executor for RealExecutor {
1689    type Sleep = Sleep;
1690
1691    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1692        tokio::task::spawn(future);
1693    }
1694
1695    fn sleep(&self, duration: Duration) -> Self::Sleep {
1696        tokio::time::sleep(duration)
1697    }
1698}
1699
1700fn broadcast<F>(
1701    sender_id: ConnectionId,
1702    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1703    mut f: F,
1704) where
1705    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1706{
1707    for receiver_id in receiver_ids {
1708        if receiver_id != sender_id {
1709            f(receiver_id).trace_err();
1710        }
1711    }
1712}
1713
// Name of the request header carrying the client's RPC protocol version. `ProtocolVersion`
// reports it from `Header::name`, and `handle_websocket_request` compares the decoded value
// against `rpc::PROTOCOL_VERSION`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1717
/// Typed form of the `x-zed-protocol-version` request header: the client's RPC protocol
/// version as a `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProtocolVersion(u32);
1719
1720impl Header for ProtocolVersion {
1721    fn name() -> &'static HeaderName {
1722        &ZED_PROTOCOL_VERSION
1723    }
1724
1725    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1726    where
1727        Self: Sized,
1728        I: Iterator<Item = &'i axum::http::HeaderValue>,
1729    {
1730        let version = values
1731            .next()
1732            .ok_or_else(|| axum::headers::Error::invalid())?
1733            .to_str()
1734            .map_err(|_| axum::headers::Error::invalid())?
1735            .parse()
1736            .map_err(|_| axum::headers::Error::invalid())?;
1737        Ok(Self(version))
1738    }
1739
1740    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1741        values.extend([self.0.to_string().parse().unwrap()]);
1742    }
1743}
1744
1745pub fn routes(server: Arc<Server>) -> Router<Body> {
1746    Router::new()
1747        .route("/rpc", get(handle_websocket_request))
1748        .layer(
1749            ServiceBuilder::new()
1750                .layer(Extension(server.app_state.clone()))
1751                .layer(middleware::from_fn(auth::validate_header)),
1752        )
1753        .route("/metrics", get(handle_metrics))
1754        .layer(Extension(server))
1755}
1756
1757pub async fn handle_websocket_request(
1758    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1759    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1760    Extension(server): Extension<Arc<Server>>,
1761    Extension(user): Extension<User>,
1762    ws: WebSocketUpgrade,
1763) -> axum::response::Response {
1764    if protocol_version != rpc::PROTOCOL_VERSION {
1765        return (
1766            StatusCode::UPGRADE_REQUIRED,
1767            "client must be upgraded".to_string(),
1768        )
1769            .into_response();
1770    }
1771    let socket_address = socket_address.to_string();
1772    ws.on_upgrade(move |socket| {
1773        use util::ResultExt;
1774        let socket = socket
1775            .map_ok(to_tungstenite_message)
1776            .err_into()
1777            .with(|message| async move { Ok(to_axum_message(message)) });
1778        let connection = Connection::new(Box::pin(socket));
1779        async move {
1780            server
1781                .handle_connection(connection, socket_address, user, None, RealExecutor)
1782                .await
1783                .log_err();
1784        }
1785    })
1786}
1787
1788pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1789    // We call `store_mut` here for its side effects of updating metrics.
1790    server.store_mut().await;
1791
1792    let encoder = prometheus::TextEncoder::new();
1793    let metric_families = prometheus::gather();
1794    match encoder.encode_to_string(&metric_families) {
1795        Ok(string) => (StatusCode::OK, string).into_response(),
1796        Err(error) => (
1797            StatusCode::INTERNAL_SERVER_ERROR,
1798            format!("failed to encode metrics {:?}", error),
1799        )
1800            .into_response(),
1801    }
1802}
1803
1804fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1805    match message {
1806        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1807        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1808        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1809        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1810        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1811            code: frame.code.into(),
1812            reason: frame.reason,
1813        })),
1814    }
1815}
1816
1817fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1818    match message {
1819        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1820        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1821        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1822        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1823        AxumMessage::Close(frame) => {
1824            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1825                code: frame.code.into(),
1826                reason: frame.reason,
1827            }))
1828        }
1829    }
1830}
1831
/// Extension trait for results whose errors should be logged through `tracing` rather than
/// propagated.
pub trait ResultExt {
    /// The success value carried by the implementing result.
    type Ok;

    /// Converts `self` into an `Option`, discarding any error.
    ///
    /// The `Result` impl in this module logs the error's `Debug` form at error level before
    /// returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1837
1838impl<T, E> ResultExt for Result<T, E>
1839where
1840    E: std::fmt::Debug,
1841{
1842    type Ok = T;
1843
1844    fn trace_err(self) -> Option<T> {
1845        match self {
1846            Ok(value) => Some(value),
1847            Err(error) => {
1848                tracing::error!("{:?}", error);
1849                None
1850            }
1851        }
1852    }
1853}