rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use prometheus::{register_int_gauge, IntGauge};
  33use rpc::{
  34    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  35    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  36};
  37use serde::{Serialize, Serializer};
  38use std::{
  39    any::TypeId,
  40    future::Future,
  41    marker::PhantomData,
  42    net::SocketAddr,
  43    ops::{Deref, DerefMut},
  44    rc::Rc,
  45    sync::{
  46        atomic::{AtomicBool, Ordering::SeqCst},
  47        Arc,
  48    },
  49    time::Duration,
  50};
  51use time::OffsetDateTime;
  52use tokio::{
  53    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  54    time::Sleep,
  55};
  56use tower::ServiceBuilder;
  57use tracing::{info_span, instrument, Instrument};
  58
  59pub use store::{Store, Worktree};
  60
  61lazy_static! {
  62    static ref METRIC_CONNECTIONS: IntGauge =
  63        register_int_gauge!("connections", "number of connections").unwrap();
  64    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
  65        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
  66    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
  67        register_int_gauge!("active_projects", "number of active projects").unwrap();
  68    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  69        "shared_projects",
  70        "number of open projects with one or more guests"
  71    )
  72    .unwrap();
  73}
  74
  75type MessageHandler =
  76    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  77
  78struct Response<R> {
  79    server: Arc<Server>,
  80    receipt: Receipt<R>,
  81    responded: Arc<AtomicBool>,
  82}
  83
  84impl<R: RequestMessage> Response<R> {
  85    fn send(self, payload: R::Response) -> Result<()> {
  86        self.responded.store(true, SeqCst);
  87        self.server.peer.respond(self.receipt, payload)?;
  88        Ok(())
  89    }
  90
  91    fn into_receipt(self) -> Receipt<R> {
  92        self.responded.store(true, SeqCst);
  93        self.receipt
  94    }
  95}
  96
  97pub struct Server {
  98    peer: Arc<Peer>,
  99    pub(crate) store: RwLock<Store>,
 100    app_state: Arc<AppState>,
 101    handlers: HashMap<TypeId, MessageHandler>,
 102    notifications: Option<mpsc::UnboundedSender<()>>,
 103}
 104
 105pub trait Executor: Send + Clone {
 106    type Sleep: Send + Future;
 107    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 108    fn sleep(&self, duration: Duration) -> Self::Sleep;
 109}
 110
 111#[derive(Clone)]
 112pub struct RealExecutor;
 113
 114const MESSAGE_COUNT_PER_PAGE: usize = 100;
 115const MAX_MESSAGE_LEN: usize = 1024;
 116
 117struct StoreReadGuard<'a> {
 118    guard: RwLockReadGuard<'a, Store>,
 119    _not_send: PhantomData<Rc<()>>,
 120}
 121
 122struct StoreWriteGuard<'a> {
 123    guard: RwLockWriteGuard<'a, Store>,
 124    _not_send: PhantomData<Rc<()>>,
 125}
 126
 127#[derive(Serialize)]
 128pub struct ServerSnapshot<'a> {
 129    peer: &'a Peer,
 130    #[serde(serialize_with = "serialize_deref")]
 131    store: RwLockReadGuard<'a, Store>,
 132}
 133
 134pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 135where
 136    S: Serializer,
 137    T: Deref<Target = U>,
 138    U: Serialize,
 139{
 140    Serialize::serialize(value.deref(), serializer)
 141}
 142
 143impl Server {
 144    pub fn new(
 145        app_state: Arc<AppState>,
 146        notifications: Option<mpsc::UnboundedSender<()>>,
 147    ) -> Arc<Self> {
 148        let mut server = Self {
 149            peer: Peer::new(),
 150            app_state,
 151            store: Default::default(),
 152            handlers: Default::default(),
 153            notifications,
 154        };
 155
 156        server
 157            .add_request_handler(Server::ping)
 158            .add_request_handler(Server::register_project)
 159            .add_request_handler(Server::unregister_project)
 160            .add_request_handler(Server::join_project)
 161            .add_message_handler(Server::leave_project)
 162            .add_message_handler(Server::respond_to_join_project_request)
 163            .add_message_handler(Server::update_project)
 164            .add_message_handler(Server::register_project_activity)
 165            .add_request_handler(Server::update_worktree)
 166            .add_message_handler(Server::start_language_server)
 167            .add_message_handler(Server::update_language_server)
 168            .add_message_handler(Server::update_diagnostic_summary)
 169            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 170            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 172            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 173            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 174            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 175            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 176            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 177            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 178            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 179            .add_request_handler(
 180                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 181            )
 182            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 183            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 184            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 185            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 186            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 187            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 188            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 189            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 190            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 191            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 192            .add_request_handler(Server::update_buffer)
 193            .add_message_handler(Server::update_buffer_file)
 194            .add_message_handler(Server::buffer_reloaded)
 195            .add_message_handler(Server::buffer_saved)
 196            .add_request_handler(Server::save_buffer)
 197            .add_request_handler(Server::get_channels)
 198            .add_request_handler(Server::get_users)
 199            .add_request_handler(Server::fuzzy_search_users)
 200            .add_request_handler(Server::request_contact)
 201            .add_request_handler(Server::remove_contact)
 202            .add_request_handler(Server::respond_to_contact_request)
 203            .add_request_handler(Server::join_channel)
 204            .add_message_handler(Server::leave_channel)
 205            .add_request_handler(Server::send_channel_message)
 206            .add_request_handler(Server::follow)
 207            .add_message_handler(Server::unfollow)
 208            .add_message_handler(Server::update_followers)
 209            .add_request_handler(Server::get_channel_messages);
 210
 211        Arc::new(server)
 212    }
 213
 214    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 215    where
 216        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 217        Fut: 'static + Send + Future<Output = Result<()>>,
 218        M: EnvelopedMessage,
 219    {
 220        let prev_handler = self.handlers.insert(
 221            TypeId::of::<M>(),
 222            Box::new(move |server, envelope| {
 223                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 224                let span = info_span!(
 225                    "handle message",
 226                    payload_type = envelope.payload_type_name()
 227                );
 228                span.in_scope(|| {
 229                    tracing::info!(
 230                        payload = format!("{:?}", envelope.payload).as_str(),
 231                        "message payload"
 232                    );
 233                });
 234                let future = (handler)(server, *envelope);
 235                async move {
 236                    if let Err(error) = future.await {
 237                        tracing::error!(%error, "error handling message");
 238                    }
 239                }
 240                .instrument(span)
 241                .boxed()
 242            }),
 243        );
 244        if prev_handler.is_some() {
 245            panic!("registered a handler for the same message twice");
 246        }
 247        self
 248    }
 249
 250    /// Handle a request while holding a lock to the store. This is useful when we're registering
 251    /// a connection but we want to respond on the connection before anybody else can send on it.
 252    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 253    where
 254        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 255        Fut: Send + Future<Output = Result<()>>,
 256        M: RequestMessage,
 257    {
 258        let handler = Arc::new(handler);
 259        self.add_message_handler(move |server, envelope| {
 260            let receipt = envelope.receipt();
 261            let handler = handler.clone();
 262            async move {
 263                let responded = Arc::new(AtomicBool::default());
 264                let response = Response {
 265                    server: server.clone(),
 266                    responded: responded.clone(),
 267                    receipt: envelope.receipt(),
 268                };
 269                match (handler)(server.clone(), envelope, response).await {
 270                    Ok(()) => {
 271                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 272                            Ok(())
 273                        } else {
 274                            Err(anyhow!("handler did not send a response"))?
 275                        }
 276                    }
 277                    Err(error) => {
 278                        server.peer.respond_with_error(
 279                            receipt,
 280                            proto::Error {
 281                                message: error.to_string(),
 282                            },
 283                        )?;
 284                        Err(error)
 285                    }
 286                }
 287            }
 288        })
 289    }
 290
 291    pub fn handle_connection<E: Executor>(
 292        self: &Arc<Self>,
 293        connection: Connection,
 294        address: String,
 295        user: User,
 296        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 297        executor: E,
 298    ) -> impl Future<Output = Result<()>> {
 299        let mut this = self.clone();
 300        let user_id = user.id;
 301        let login = user.github_login;
 302        let span = info_span!("handle connection", %user_id, %login, %address);
 303        async move {
 304            let (connection_id, handle_io, mut incoming_rx) = this
 305                .peer
 306                .add_connection(connection, {
 307                    let executor = executor.clone();
 308                    move |duration| {
 309                        let timer = executor.sleep(duration);
 310                        async move {
 311                            timer.await;
 312                        }
 313                    }
 314                })
 315                .await;
 316
 317            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 318
 319            if let Some(send_connection_id) = send_connection_id.as_mut() {
 320                let _ = send_connection_id.send(connection_id).await;
 321            }
 322
 323            if !user.connected_once {
 324                this.peer.send(connection_id, proto::ShowContacts {})?;
 325                this.app_state.db.set_user_connected_once(user_id, true).await?;
 326            }
 327
 328            let (contacts, invite_code) = future::try_join(
 329                this.app_state.db.get_contacts(user_id),
 330                this.app_state.db.get_invite_code_for_user(user_id)
 331            ).await?;
 332
 333            {
 334                let mut store = this.store_mut().await;
 335                store.add_connection(connection_id, user_id, user.admin);
 336                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 337
 338                if let Some((code, count)) = invite_code {
 339                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 340                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 341                        count,
 342                    })?;
 343                }
 344            }
 345            this.update_user_contacts(user_id).await?;
 346
 347            let handle_io = handle_io.fuse();
 348            futures::pin_mut!(handle_io);
 349            loop {
 350                let next_message = incoming_rx.next().fuse();
 351                futures::pin_mut!(next_message);
 352                futures::select_biased! {
 353                    result = handle_io => {
 354                        if let Err(error) = result {
 355                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 356                        }
 357                        break;
 358                    }
 359                    message = next_message => {
 360                        if let Some(message) = message {
 361                            let type_name = message.payload_type_name();
 362                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 363                            async {
 364                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 365                                    let notifications = this.notifications.clone();
 366                                    let is_background = message.is_background();
 367                                    let handle_message = (handler)(this.clone(), message);
 368                                    let handle_message = async move {
 369                                        handle_message.await;
 370                                        if let Some(mut notifications) = notifications {
 371                                            let _ = notifications.send(()).await;
 372                                        }
 373                                    };
 374                                    if is_background {
 375                                        executor.spawn_detached(handle_message);
 376                                    } else {
 377                                        handle_message.await;
 378                                    }
 379                                } else {
 380                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 381                                }
 382                            }.instrument(span).await;
 383                        } else {
 384                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 385                            break;
 386                        }
 387                    }
 388                }
 389            }
 390
 391            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 392            if let Err(error) = this.sign_out(connection_id).await {
 393                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 394            }
 395
 396            Ok(())
 397        }.instrument(span)
 398    }
 399
 400    #[instrument(skip(self), err)]
 401    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 402        self.peer.disconnect(connection_id);
 403
 404        let removed_user_id = {
 405            let mut store = self.store_mut().await;
 406            let removed_connection = store.remove_connection(connection_id)?;
 407
 408            for (project_id, project) in removed_connection.hosted_projects {
 409                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 410                    self.peer
 411                        .send(conn_id, proto::UnregisterProject { project_id })
 412                });
 413
 414                for (_, receipts) in project.join_requests {
 415                    for receipt in receipts {
 416                        self.peer.respond(
 417                            receipt,
 418                            proto::JoinProjectResponse {
 419                                variant: Some(proto::join_project_response::Variant::Decline(
 420                                    proto::join_project_response::Decline {
 421                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 422                                    },
 423                                )),
 424                            },
 425                        )?;
 426                    }
 427                }
 428            }
 429
 430            for project_id in removed_connection.guest_project_ids {
 431                if let Some(project) = store.project(project_id).trace_err() {
 432                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 433                        self.peer.send(
 434                            conn_id,
 435                            proto::RemoveProjectCollaborator {
 436                                project_id,
 437                                peer_id: connection_id.0,
 438                            },
 439                        )
 440                    });
 441                    if project.guests.is_empty() {
 442                        self.peer
 443                            .send(
 444                                project.host_connection_id,
 445                                proto::ProjectUnshared { project_id },
 446                            )
 447                            .trace_err();
 448                    }
 449                }
 450            }
 451
 452            removed_connection.user_id
 453        };
 454
 455        self.update_user_contacts(removed_user_id).await?;
 456
 457        Ok(())
 458    }
 459
 460    pub async fn invite_code_redeemed(
 461        self: &Arc<Self>,
 462        code: &str,
 463        invitee_id: UserId,
 464    ) -> Result<()> {
 465        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 466        let store = self.store().await;
 467        let invitee_contact = store.contact_for_user(invitee_id, true);
 468        for connection_id in store.connection_ids_for_user(user.id) {
 469            self.peer.send(
 470                connection_id,
 471                proto::UpdateContacts {
 472                    contacts: vec![invitee_contact.clone()],
 473                    ..Default::default()
 474                },
 475            )?;
 476            self.peer.send(
 477                connection_id,
 478                proto::UpdateInviteInfo {
 479                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 480                    count: user.invite_count as u32,
 481                },
 482            )?;
 483        }
 484        Ok(())
 485    }
 486
 487    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 488        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 489            if let Some(invite_code) = &user.invite_code {
 490                let store = self.store().await;
 491                for connection_id in store.connection_ids_for_user(user_id) {
 492                    self.peer.send(
 493                        connection_id,
 494                        proto::UpdateInviteInfo {
 495                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 496                            count: user.invite_count as u32,
 497                        },
 498                    )?;
 499                }
 500            }
 501        }
 502        Ok(())
 503    }
 504
 505    async fn ping(
 506        self: Arc<Server>,
 507        _: TypedEnvelope<proto::Ping>,
 508        response: Response<proto::Ping>,
 509    ) -> Result<()> {
 510        response.send(proto::Ack {})?;
 511        Ok(())
 512    }
 513
 514    async fn register_project(
 515        self: Arc<Server>,
 516        request: TypedEnvelope<proto::RegisterProject>,
 517        response: Response<proto::RegisterProject>,
 518    ) -> Result<()> {
 519        let project_id;
 520        {
 521            let mut state = self.store_mut().await;
 522            let user_id = state.user_id_for_connection(request.sender_id)?;
 523            project_id = state.register_project(request.sender_id, user_id);
 524        };
 525
 526        response.send(proto::RegisterProjectResponse { project_id })?;
 527
 528        Ok(())
 529    }
 530
 531    async fn unregister_project(
 532        self: Arc<Server>,
 533        request: TypedEnvelope<proto::UnregisterProject>,
 534        response: Response<proto::UnregisterProject>,
 535    ) -> Result<()> {
 536        let (user_id, project) = {
 537            let mut state = self.store_mut().await;
 538            let project =
 539                state.unregister_project(request.payload.project_id, request.sender_id)?;
 540            (state.user_id_for_connection(request.sender_id)?, project)
 541        };
 542
 543        broadcast(
 544            request.sender_id,
 545            project.guests.keys().copied(),
 546            |conn_id| {
 547                self.peer.send(
 548                    conn_id,
 549                    proto::UnregisterProject {
 550                        project_id: request.payload.project_id,
 551                    },
 552                )
 553            },
 554        );
 555        for (_, receipts) in project.join_requests {
 556            for receipt in receipts {
 557                self.peer.respond(
 558                    receipt,
 559                    proto::JoinProjectResponse {
 560                        variant: Some(proto::join_project_response::Variant::Decline(
 561                            proto::join_project_response::Decline {
 562                                reason: proto::join_project_response::decline::Reason::Closed
 563                                    as i32,
 564                            },
 565                        )),
 566                    },
 567                )?;
 568            }
 569        }
 570
 571        // Send out the `UpdateContacts` message before responding to the unregister
 572        // request. This way, when the project's host can keep track of the project's
 573        // remote id until after they've received the `UpdateContacts` message for
 574        // themself.
 575        self.update_user_contacts(user_id).await?;
 576        response.send(proto::Ack {})?;
 577
 578        Ok(())
 579    }
 580
 581    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 582        let contacts = self.app_state.db.get_contacts(user_id).await?;
 583        let store = self.store().await;
 584        let updated_contact = store.contact_for_user(user_id, false);
 585        for contact in contacts {
 586            if let db::Contact::Accepted {
 587                user_id: contact_user_id,
 588                ..
 589            } = contact
 590            {
 591                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 592                    self.peer
 593                        .send(
 594                            contact_conn_id,
 595                            proto::UpdateContacts {
 596                                contacts: vec![updated_contact.clone()],
 597                                remove_contacts: Default::default(),
 598                                incoming_requests: Default::default(),
 599                                remove_incoming_requests: Default::default(),
 600                                outgoing_requests: Default::default(),
 601                                remove_outgoing_requests: Default::default(),
 602                            },
 603                        )
 604                        .trace_err();
 605                }
 606            }
 607        }
 608        Ok(())
 609    }
 610
 611    async fn join_project(
 612        self: Arc<Server>,
 613        request: TypedEnvelope<proto::JoinProject>,
 614        response: Response<proto::JoinProject>,
 615    ) -> Result<()> {
 616        let project_id = request.payload.project_id;
 617
 618        let host_user_id;
 619        let guest_user_id;
 620        let host_connection_id;
 621        {
 622            let state = self.store().await;
 623            let project = state.project(project_id)?;
 624            host_user_id = project.host_user_id;
 625            host_connection_id = project.host_connection_id;
 626            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 627        };
 628
 629        tracing::info!(project_id, %host_user_id, %host_connection_id, "join project");
 630        let has_contact = self
 631            .app_state
 632            .db
 633            .has_contact(guest_user_id, host_user_id)
 634            .await?;
 635        if !has_contact {
 636            return Err(anyhow!("no such project"))?;
 637        }
 638
 639        self.store_mut().await.request_join_project(
 640            guest_user_id,
 641            project_id,
 642            response.into_receipt(),
 643        )?;
 644        self.peer.send(
 645            host_connection_id,
 646            proto::RequestJoinProject {
 647                project_id,
 648                requester_id: guest_user_id.to_proto(),
 649            },
 650        )?;
 651        Ok(())
 652    }
 653
 654    async fn respond_to_join_project_request(
 655        self: Arc<Server>,
 656        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 657    ) -> Result<()> {
 658        let host_user_id;
 659
 660        {
 661            let mut state = self.store_mut().await;
 662            let project_id = request.payload.project_id;
 663            let project = state.project(project_id)?;
 664            if project.host_connection_id != request.sender_id {
 665                Err(anyhow!("no such connection"))?;
 666            }
 667
 668            host_user_id = project.host_user_id;
 669            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 670
 671            if !request.payload.allow {
 672                let receipts = state
 673                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 674                    .ok_or_else(|| anyhow!("no such request"))?;
 675                for receipt in receipts {
 676                    self.peer.respond(
 677                        receipt,
 678                        proto::JoinProjectResponse {
 679                            variant: Some(proto::join_project_response::Variant::Decline(
 680                                proto::join_project_response::Decline {
 681                                    reason: proto::join_project_response::decline::Reason::Declined
 682                                        as i32,
 683                                },
 684                            )),
 685                        },
 686                    )?;
 687                }
 688                return Ok(());
 689            }
 690
 691            let (receipts_with_replica_ids, project) = state
 692                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 693                .ok_or_else(|| anyhow!("no such request"))?;
 694
 695            let peer_count = project.guests.len();
 696            let mut collaborators = Vec::with_capacity(peer_count);
 697            collaborators.push(proto::Collaborator {
 698                peer_id: project.host_connection_id.0,
 699                replica_id: 0,
 700                user_id: project.host_user_id.to_proto(),
 701            });
 702            let worktrees = project
 703                .worktrees
 704                .iter()
 705                .filter_map(|(id, shared_worktree)| {
 706                    let worktree = project.worktrees.get(&id)?;
 707                    Some(proto::Worktree {
 708                        id: *id,
 709                        root_name: worktree.root_name.clone(),
 710                        entries: shared_worktree.entries.values().cloned().collect(),
 711                        diagnostic_summaries: shared_worktree
 712                            .diagnostic_summaries
 713                            .values()
 714                            .cloned()
 715                            .collect(),
 716                        visible: worktree.visible,
 717                        scan_id: shared_worktree.scan_id,
 718                    })
 719                })
 720                .collect::<Vec<_>>();
 721
 722            // Add all guests other than the requesting user's own connections as collaborators
 723            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 724                if receipts_with_replica_ids
 725                    .iter()
 726                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 727                {
 728                    collaborators.push(proto::Collaborator {
 729                        peer_id: peer_conn_id.0,
 730                        replica_id: *peer_replica_id as u32,
 731                        user_id: peer_user_id.to_proto(),
 732                    });
 733                }
 734            }
 735
 736            for conn_id in project.connection_ids() {
 737                for (receipt, replica_id) in &receipts_with_replica_ids {
 738                    if conn_id != receipt.sender_id {
 739                        self.peer.send(
 740                            conn_id,
 741                            proto::AddProjectCollaborator {
 742                                project_id,
 743                                collaborator: Some(proto::Collaborator {
 744                                    peer_id: receipt.sender_id.0,
 745                                    replica_id: *replica_id as u32,
 746                                    user_id: guest_user_id.to_proto(),
 747                                }),
 748                            },
 749                        )?;
 750                    }
 751                }
 752            }
 753
 754            for (receipt, replica_id) in receipts_with_replica_ids {
 755                self.peer.respond(
 756                    receipt,
 757                    proto::JoinProjectResponse {
 758                        variant: Some(proto::join_project_response::Variant::Accept(
 759                            proto::join_project_response::Accept {
 760                                worktrees: worktrees.clone(),
 761                                replica_id: replica_id as u32,
 762                                collaborators: collaborators.clone(),
 763                                language_servers: project.language_servers.clone(),
 764                            },
 765                        )),
 766                    },
 767                )?;
 768            }
 769        }
 770
 771        self.update_user_contacts(host_user_id).await?;
 772        Ok(())
 773    }
 774
 775    async fn leave_project(
 776        self: Arc<Server>,
 777        request: TypedEnvelope<proto::LeaveProject>,
 778    ) -> Result<()> {
 779        let sender_id = request.sender_id;
 780        let project_id = request.payload.project_id;
 781        let project;
 782        {
 783            let mut store = self.store_mut().await;
 784            project = store.leave_project(sender_id, project_id)?;
 785            tracing::info!(
 786                project_id,
 787                host_user_id = %project.host_user_id,
 788                host_connection_id = %project.host_connection_id,
 789                "leave project"
 790            );
 791
 792            if project.remove_collaborator {
 793                broadcast(sender_id, project.connection_ids, |conn_id| {
 794                    self.peer.send(
 795                        conn_id,
 796                        proto::RemoveProjectCollaborator {
 797                            project_id,
 798                            peer_id: sender_id.0,
 799                        },
 800                    )
 801                });
 802            }
 803
 804            if let Some(requester_id) = project.cancel_request {
 805                self.peer.send(
 806                    project.host_connection_id,
 807                    proto::JoinProjectRequestCancelled {
 808                        project_id,
 809                        requester_id: requester_id.to_proto(),
 810                    },
 811                )?;
 812            }
 813
 814            if project.unshare {
 815                self.peer.send(
 816                    project.host_connection_id,
 817                    proto::ProjectUnshared { project_id },
 818                )?;
 819            }
 820        }
 821        self.update_user_contacts(project.host_user_id).await?;
 822        Ok(())
 823    }
 824
 825    async fn update_project(
 826        self: Arc<Server>,
 827        request: TypedEnvelope<proto::UpdateProject>,
 828    ) -> Result<()> {
 829        let user_id;
 830        {
 831            let mut state = self.store_mut().await;
 832            user_id = state.user_id_for_connection(request.sender_id)?;
 833            let guest_connection_ids = state
 834                .read_project(request.payload.project_id, request.sender_id)?
 835                .guest_connection_ids();
 836            state.update_project(
 837                request.payload.project_id,
 838                &request.payload.worktrees,
 839                request.sender_id,
 840            )?;
 841            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 842                self.peer
 843                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 844            });
 845        };
 846        self.update_user_contacts(user_id).await?;
 847        Ok(())
 848    }
 849
 850    async fn register_project_activity(
 851        self: Arc<Server>,
 852        request: TypedEnvelope<proto::RegisterProjectActivity>,
 853    ) -> Result<()> {
 854        self.store_mut()
 855            .await
 856            .register_project_activity(request.payload.project_id, request.sender_id)?;
 857        Ok(())
 858    }
 859
 860    async fn update_worktree(
 861        self: Arc<Server>,
 862        request: TypedEnvelope<proto::UpdateWorktree>,
 863        response: Response<proto::UpdateWorktree>,
 864    ) -> Result<()> {
 865        let (connection_ids, metadata_changed) = {
 866            let mut store = self.store_mut().await;
 867            let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
 868                request.sender_id,
 869                request.payload.project_id,
 870                request.payload.worktree_id,
 871                &request.payload.root_name,
 872                &request.payload.removed_entries,
 873                &request.payload.updated_entries,
 874                request.payload.scan_id,
 875            )?;
 876            for (extension, count) in extension_counts {
 877                tracing::info!(
 878                    project_id = request.payload.project_id,
 879                    worktree_id = request.payload.worktree_id,
 880                    ?extension,
 881                    %count,
 882                    "worktree updated"
 883                );
 884            }
 885            (connection_ids, metadata_changed)
 886        };
 887
 888        broadcast(request.sender_id, connection_ids, |connection_id| {
 889            self.peer
 890                .forward_send(request.sender_id, connection_id, request.payload.clone())
 891        });
 892        if metadata_changed {
 893            let user_id = self
 894                .store()
 895                .await
 896                .user_id_for_connection(request.sender_id)?;
 897            self.update_user_contacts(user_id).await?;
 898        }
 899        response.send(proto::Ack {})?;
 900        Ok(())
 901    }
 902
 903    async fn update_diagnostic_summary(
 904        self: Arc<Server>,
 905        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 906    ) -> Result<()> {
 907        let summary = request
 908            .payload
 909            .summary
 910            .clone()
 911            .ok_or_else(|| anyhow!("invalid summary"))?;
 912        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 913            request.payload.project_id,
 914            request.payload.worktree_id,
 915            request.sender_id,
 916            summary,
 917        )?;
 918
 919        broadcast(request.sender_id, receiver_ids, |connection_id| {
 920            self.peer
 921                .forward_send(request.sender_id, connection_id, request.payload.clone())
 922        });
 923        Ok(())
 924    }
 925
 926    async fn start_language_server(
 927        self: Arc<Server>,
 928        request: TypedEnvelope<proto::StartLanguageServer>,
 929    ) -> Result<()> {
 930        let receiver_ids = self.store_mut().await.start_language_server(
 931            request.payload.project_id,
 932            request.sender_id,
 933            request
 934                .payload
 935                .server
 936                .clone()
 937                .ok_or_else(|| anyhow!("invalid language server"))?,
 938        )?;
 939        broadcast(request.sender_id, receiver_ids, |connection_id| {
 940            self.peer
 941                .forward_send(request.sender_id, connection_id, request.payload.clone())
 942        });
 943        Ok(())
 944    }
 945
 946    async fn update_language_server(
 947        self: Arc<Server>,
 948        request: TypedEnvelope<proto::UpdateLanguageServer>,
 949    ) -> Result<()> {
 950        let receiver_ids = self
 951            .store()
 952            .await
 953            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 954        broadcast(request.sender_id, receiver_ids, |connection_id| {
 955            self.peer
 956                .forward_send(request.sender_id, connection_id, request.payload.clone())
 957        });
 958        Ok(())
 959    }
 960
 961    async fn forward_project_request<T>(
 962        self: Arc<Server>,
 963        request: TypedEnvelope<T>,
 964        response: Response<T>,
 965    ) -> Result<()>
 966    where
 967        T: EntityMessage + RequestMessage,
 968    {
 969        let host_connection_id = self
 970            .store()
 971            .await
 972            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 973            .host_connection_id;
 974
 975        response.send(
 976            self.peer
 977                .forward_request(request.sender_id, host_connection_id, request.payload)
 978                .await?,
 979        )?;
 980        Ok(())
 981    }
 982
 983    async fn save_buffer(
 984        self: Arc<Server>,
 985        request: TypedEnvelope<proto::SaveBuffer>,
 986        response: Response<proto::SaveBuffer>,
 987    ) -> Result<()> {
 988        let host = self
 989            .store()
 990            .await
 991            .read_project(request.payload.project_id, request.sender_id)?
 992            .host_connection_id;
 993        let response_payload = self
 994            .peer
 995            .forward_request(request.sender_id, host, request.payload.clone())
 996            .await?;
 997
 998        let mut guests = self
 999            .store()
1000            .await
1001            .read_project(request.payload.project_id, request.sender_id)?
1002            .connection_ids();
1003        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1004        broadcast(host, guests, |conn_id| {
1005            self.peer
1006                .forward_send(host, conn_id, response_payload.clone())
1007        });
1008        response.send(response_payload)?;
1009        Ok(())
1010    }
1011
1012    async fn update_buffer(
1013        self: Arc<Server>,
1014        request: TypedEnvelope<proto::UpdateBuffer>,
1015        response: Response<proto::UpdateBuffer>,
1016    ) -> Result<()> {
1017        let receiver_ids = {
1018            let mut store = self.store_mut().await;
1019            store.register_project_activity(request.payload.project_id, request.sender_id)?;
1020            store.project_connection_ids(request.payload.project_id, request.sender_id)?
1021        };
1022
1023        broadcast(request.sender_id, receiver_ids, |connection_id| {
1024            self.peer
1025                .forward_send(request.sender_id, connection_id, request.payload.clone())
1026        });
1027        response.send(proto::Ack {})?;
1028        Ok(())
1029    }
1030
1031    async fn update_buffer_file(
1032        self: Arc<Server>,
1033        request: TypedEnvelope<proto::UpdateBufferFile>,
1034    ) -> Result<()> {
1035        let receiver_ids = self
1036            .store()
1037            .await
1038            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1039        broadcast(request.sender_id, receiver_ids, |connection_id| {
1040            self.peer
1041                .forward_send(request.sender_id, connection_id, request.payload.clone())
1042        });
1043        Ok(())
1044    }
1045
1046    async fn buffer_reloaded(
1047        self: Arc<Server>,
1048        request: TypedEnvelope<proto::BufferReloaded>,
1049    ) -> Result<()> {
1050        let receiver_ids = self
1051            .store()
1052            .await
1053            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1054        broadcast(request.sender_id, receiver_ids, |connection_id| {
1055            self.peer
1056                .forward_send(request.sender_id, connection_id, request.payload.clone())
1057        });
1058        Ok(())
1059    }
1060
1061    async fn buffer_saved(
1062        self: Arc<Server>,
1063        request: TypedEnvelope<proto::BufferSaved>,
1064    ) -> Result<()> {
1065        let receiver_ids = self
1066            .store()
1067            .await
1068            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1069        broadcast(request.sender_id, receiver_ids, |connection_id| {
1070            self.peer
1071                .forward_send(request.sender_id, connection_id, request.payload.clone())
1072        });
1073        Ok(())
1074    }
1075
1076    async fn follow(
1077        self: Arc<Self>,
1078        request: TypedEnvelope<proto::Follow>,
1079        response: Response<proto::Follow>,
1080    ) -> Result<()> {
1081        let leader_id = ConnectionId(request.payload.leader_id);
1082        let follower_id = request.sender_id;
1083        {
1084            let mut store = self.store_mut().await;
1085            if !store
1086                .project_connection_ids(request.payload.project_id, follower_id)?
1087                .contains(&leader_id)
1088            {
1089                Err(anyhow!("no such peer"))?;
1090            }
1091
1092            store.register_project_activity(request.payload.project_id, follower_id)?;
1093        }
1094
1095        let mut response_payload = self
1096            .peer
1097            .forward_request(request.sender_id, leader_id, request.payload)
1098            .await?;
1099        response_payload
1100            .views
1101            .retain(|view| view.leader_id != Some(follower_id.0));
1102        response.send(response_payload)?;
1103        Ok(())
1104    }
1105
1106    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1107        let leader_id = ConnectionId(request.payload.leader_id);
1108        let mut store = self.store_mut().await;
1109        if !store
1110            .project_connection_ids(request.payload.project_id, request.sender_id)?
1111            .contains(&leader_id)
1112        {
1113            Err(anyhow!("no such peer"))?;
1114        }
1115        store.register_project_activity(request.payload.project_id, request.sender_id)?;
1116        self.peer
1117            .forward_send(request.sender_id, leader_id, request.payload)?;
1118        Ok(())
1119    }
1120
1121    async fn update_followers(
1122        self: Arc<Self>,
1123        request: TypedEnvelope<proto::UpdateFollowers>,
1124    ) -> Result<()> {
1125        let mut store = self.store_mut().await;
1126        store.register_project_activity(request.payload.project_id, request.sender_id)?;
1127        let connection_ids =
1128            store.project_connection_ids(request.payload.project_id, request.sender_id)?;
1129        let leader_id = request
1130            .payload
1131            .variant
1132            .as_ref()
1133            .and_then(|variant| match variant {
1134                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1135                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1136                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1137            });
1138        for follower_id in &request.payload.follower_ids {
1139            let follower_id = ConnectionId(*follower_id);
1140            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1141                self.peer
1142                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1143            }
1144        }
1145        Ok(())
1146    }
1147
1148    async fn get_channels(
1149        self: Arc<Server>,
1150        request: TypedEnvelope<proto::GetChannels>,
1151        response: Response<proto::GetChannels>,
1152    ) -> Result<()> {
1153        let user_id = self
1154            .store()
1155            .await
1156            .user_id_for_connection(request.sender_id)?;
1157        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1158        response.send(proto::GetChannelsResponse {
1159            channels: channels
1160                .into_iter()
1161                .map(|chan| proto::Channel {
1162                    id: chan.id.to_proto(),
1163                    name: chan.name,
1164                })
1165                .collect(),
1166        })?;
1167        Ok(())
1168    }
1169
1170    async fn get_users(
1171        self: Arc<Server>,
1172        request: TypedEnvelope<proto::GetUsers>,
1173        response: Response<proto::GetUsers>,
1174    ) -> Result<()> {
1175        let user_ids = request
1176            .payload
1177            .user_ids
1178            .into_iter()
1179            .map(UserId::from_proto)
1180            .collect();
1181        let users = self
1182            .app_state
1183            .db
1184            .get_users_by_ids(user_ids)
1185            .await?
1186            .into_iter()
1187            .map(|user| proto::User {
1188                id: user.id.to_proto(),
1189                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1190                github_login: user.github_login,
1191            })
1192            .collect();
1193        response.send(proto::UsersResponse { users })?;
1194        Ok(())
1195    }
1196
1197    async fn fuzzy_search_users(
1198        self: Arc<Server>,
1199        request: TypedEnvelope<proto::FuzzySearchUsers>,
1200        response: Response<proto::FuzzySearchUsers>,
1201    ) -> Result<()> {
1202        let user_id = self
1203            .store()
1204            .await
1205            .user_id_for_connection(request.sender_id)?;
1206        let query = request.payload.query;
1207        let db = &self.app_state.db;
1208        let users = match query.len() {
1209            0 => vec![],
1210            1 | 2 => db
1211                .get_user_by_github_login(&query)
1212                .await?
1213                .into_iter()
1214                .collect(),
1215            _ => db.fuzzy_search_users(&query, 10).await?,
1216        };
1217        let users = users
1218            .into_iter()
1219            .filter(|user| user.id != user_id)
1220            .map(|user| proto::User {
1221                id: user.id.to_proto(),
1222                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1223                github_login: user.github_login,
1224            })
1225            .collect();
1226        response.send(proto::UsersResponse { users })?;
1227        Ok(())
1228    }
1229
1230    async fn request_contact(
1231        self: Arc<Server>,
1232        request: TypedEnvelope<proto::RequestContact>,
1233        response: Response<proto::RequestContact>,
1234    ) -> Result<()> {
1235        let requester_id = self
1236            .store()
1237            .await
1238            .user_id_for_connection(request.sender_id)?;
1239        let responder_id = UserId::from_proto(request.payload.responder_id);
1240        if requester_id == responder_id {
1241            return Err(anyhow!("cannot add yourself as a contact"))?;
1242        }
1243
1244        self.app_state
1245            .db
1246            .send_contact_request(requester_id, responder_id)
1247            .await?;
1248
1249        // Update outgoing contact requests of requester
1250        let mut update = proto::UpdateContacts::default();
1251        update.outgoing_requests.push(responder_id.to_proto());
1252        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1253            self.peer.send(connection_id, update.clone())?;
1254        }
1255
1256        // Update incoming contact requests of responder
1257        let mut update = proto::UpdateContacts::default();
1258        update
1259            .incoming_requests
1260            .push(proto::IncomingContactRequest {
1261                requester_id: requester_id.to_proto(),
1262                should_notify: true,
1263            });
1264        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1265            self.peer.send(connection_id, update.clone())?;
1266        }
1267
1268        response.send(proto::Ack {})?;
1269        Ok(())
1270    }
1271
1272    async fn respond_to_contact_request(
1273        self: Arc<Server>,
1274        request: TypedEnvelope<proto::RespondToContactRequest>,
1275        response: Response<proto::RespondToContactRequest>,
1276    ) -> Result<()> {
1277        let responder_id = self
1278            .store()
1279            .await
1280            .user_id_for_connection(request.sender_id)?;
1281        let requester_id = UserId::from_proto(request.payload.requester_id);
1282        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1283            self.app_state
1284                .db
1285                .dismiss_contact_notification(responder_id, requester_id)
1286                .await?;
1287        } else {
1288            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1289            self.app_state
1290                .db
1291                .respond_to_contact_request(responder_id, requester_id, accept)
1292                .await?;
1293
1294            let store = self.store().await;
1295            // Update responder with new contact
1296            let mut update = proto::UpdateContacts::default();
1297            if accept {
1298                update
1299                    .contacts
1300                    .push(store.contact_for_user(requester_id, false));
1301            }
1302            update
1303                .remove_incoming_requests
1304                .push(requester_id.to_proto());
1305            for connection_id in store.connection_ids_for_user(responder_id) {
1306                self.peer.send(connection_id, update.clone())?;
1307            }
1308
1309            // Update requester with new contact
1310            let mut update = proto::UpdateContacts::default();
1311            if accept {
1312                update
1313                    .contacts
1314                    .push(store.contact_for_user(responder_id, true));
1315            }
1316            update
1317                .remove_outgoing_requests
1318                .push(responder_id.to_proto());
1319            for connection_id in store.connection_ids_for_user(requester_id) {
1320                self.peer.send(connection_id, update.clone())?;
1321            }
1322        }
1323
1324        response.send(proto::Ack {})?;
1325        Ok(())
1326    }
1327
1328    async fn remove_contact(
1329        self: Arc<Server>,
1330        request: TypedEnvelope<proto::RemoveContact>,
1331        response: Response<proto::RemoveContact>,
1332    ) -> Result<()> {
1333        let requester_id = self
1334            .store()
1335            .await
1336            .user_id_for_connection(request.sender_id)?;
1337        let responder_id = UserId::from_proto(request.payload.user_id);
1338        self.app_state
1339            .db
1340            .remove_contact(requester_id, responder_id)
1341            .await?;
1342
1343        // Update outgoing contact requests of requester
1344        let mut update = proto::UpdateContacts::default();
1345        update
1346            .remove_outgoing_requests
1347            .push(responder_id.to_proto());
1348        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1349            self.peer.send(connection_id, update.clone())?;
1350        }
1351
1352        // Update incoming contact requests of responder
1353        let mut update = proto::UpdateContacts::default();
1354        update
1355            .remove_incoming_requests
1356            .push(requester_id.to_proto());
1357        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1358            self.peer.send(connection_id, update.clone())?;
1359        }
1360
1361        response.send(proto::Ack {})?;
1362        Ok(())
1363    }
1364
1365    async fn join_channel(
1366        self: Arc<Self>,
1367        request: TypedEnvelope<proto::JoinChannel>,
1368        response: Response<proto::JoinChannel>,
1369    ) -> Result<()> {
1370        let user_id = self
1371            .store()
1372            .await
1373            .user_id_for_connection(request.sender_id)?;
1374        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1375        if !self
1376            .app_state
1377            .db
1378            .can_user_access_channel(user_id, channel_id)
1379            .await?
1380        {
1381            Err(anyhow!("access denied"))?;
1382        }
1383
1384        self.store_mut()
1385            .await
1386            .join_channel(request.sender_id, channel_id);
1387        let messages = self
1388            .app_state
1389            .db
1390            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1391            .await?
1392            .into_iter()
1393            .map(|msg| proto::ChannelMessage {
1394                id: msg.id.to_proto(),
1395                body: msg.body,
1396                timestamp: msg.sent_at.unix_timestamp() as u64,
1397                sender_id: msg.sender_id.to_proto(),
1398                nonce: Some(msg.nonce.as_u128().into()),
1399            })
1400            .collect::<Vec<_>>();
1401        response.send(proto::JoinChannelResponse {
1402            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1403            messages,
1404        })?;
1405        Ok(())
1406    }
1407
1408    async fn leave_channel(
1409        self: Arc<Self>,
1410        request: TypedEnvelope<proto::LeaveChannel>,
1411    ) -> Result<()> {
1412        let user_id = self
1413            .store()
1414            .await
1415            .user_id_for_connection(request.sender_id)?;
1416        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1417        if !self
1418            .app_state
1419            .db
1420            .can_user_access_channel(user_id, channel_id)
1421            .await?
1422        {
1423            Err(anyhow!("access denied"))?;
1424        }
1425
1426        self.store_mut()
1427            .await
1428            .leave_channel(request.sender_id, channel_id);
1429
1430        Ok(())
1431    }
1432
1433    async fn send_channel_message(
1434        self: Arc<Self>,
1435        request: TypedEnvelope<proto::SendChannelMessage>,
1436        response: Response<proto::SendChannelMessage>,
1437    ) -> Result<()> {
1438        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1439        let user_id;
1440        let connection_ids;
1441        {
1442            let state = self.store().await;
1443            user_id = state.user_id_for_connection(request.sender_id)?;
1444            connection_ids = state.channel_connection_ids(channel_id)?;
1445        }
1446
1447        // Validate the message body.
1448        let body = request.payload.body.trim().to_string();
1449        if body.len() > MAX_MESSAGE_LEN {
1450            return Err(anyhow!("message is too long"))?;
1451        }
1452        if body.is_empty() {
1453            return Err(anyhow!("message can't be blank"))?;
1454        }
1455
1456        let timestamp = OffsetDateTime::now_utc();
1457        let nonce = request
1458            .payload
1459            .nonce
1460            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1461
1462        let message_id = self
1463            .app_state
1464            .db
1465            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1466            .await?
1467            .to_proto();
1468        let message = proto::ChannelMessage {
1469            sender_id: user_id.to_proto(),
1470            id: message_id,
1471            body,
1472            timestamp: timestamp.unix_timestamp() as u64,
1473            nonce: Some(nonce),
1474        };
1475        broadcast(request.sender_id, connection_ids, |conn_id| {
1476            self.peer.send(
1477                conn_id,
1478                proto::ChannelMessageSent {
1479                    channel_id: channel_id.to_proto(),
1480                    message: Some(message.clone()),
1481                },
1482            )
1483        });
1484        response.send(proto::SendChannelMessageResponse {
1485            message: Some(message),
1486        })?;
1487        Ok(())
1488    }
1489
1490    async fn get_channel_messages(
1491        self: Arc<Self>,
1492        request: TypedEnvelope<proto::GetChannelMessages>,
1493        response: Response<proto::GetChannelMessages>,
1494    ) -> Result<()> {
1495        let user_id = self
1496            .store()
1497            .await
1498            .user_id_for_connection(request.sender_id)?;
1499        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1500        if !self
1501            .app_state
1502            .db
1503            .can_user_access_channel(user_id, channel_id)
1504            .await?
1505        {
1506            Err(anyhow!("access denied"))?;
1507        }
1508
1509        let messages = self
1510            .app_state
1511            .db
1512            .get_channel_messages(
1513                channel_id,
1514                MESSAGE_COUNT_PER_PAGE,
1515                Some(MessageId::from_proto(request.payload.before_message_id)),
1516            )
1517            .await?
1518            .into_iter()
1519            .map(|msg| proto::ChannelMessage {
1520                id: msg.id.to_proto(),
1521                body: msg.body,
1522                timestamp: msg.sent_at.unix_timestamp() as u64,
1523                sender_id: msg.sender_id.to_proto(),
1524                nonce: Some(msg.nonce.as_u128().into()),
1525            })
1526            .collect::<Vec<_>>();
1527        response.send(proto::GetChannelMessagesResponse {
1528            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1529            messages,
1530        })?;
1531        Ok(())
1532    }
1533
1534    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1535        #[cfg(test)]
1536        tokio::task::yield_now().await;
1537        let guard = self.store.read().await;
1538        #[cfg(test)]
1539        tokio::task::yield_now().await;
1540        StoreReadGuard {
1541            guard,
1542            _not_send: PhantomData,
1543        }
1544    }
1545
1546    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1547        #[cfg(test)]
1548        tokio::task::yield_now().await;
1549        let guard = self.store.write().await;
1550        #[cfg(test)]
1551        tokio::task::yield_now().await;
1552        StoreWriteGuard {
1553            guard,
1554            _not_send: PhantomData,
1555        }
1556    }
1557
1558    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1559        ServerSnapshot {
1560            store: self.store.read().await,
1561            peer: &self.peer,
1562        }
1563    }
1564}
1565
1566impl<'a> Deref for StoreReadGuard<'a> {
1567    type Target = Store;
1568
1569    fn deref(&self) -> &Self::Target {
1570        &*self.guard
1571    }
1572}
1573
1574impl<'a> Deref for StoreWriteGuard<'a> {
1575    type Target = Store;
1576
1577    fn deref(&self) -> &Self::Target {
1578        &*self.guard
1579    }
1580}
1581
1582impl<'a> DerefMut for StoreWriteGuard<'a> {
1583    fn deref_mut(&mut self) -> &mut Self::Target {
1584        &mut *self.guard
1585    }
1586}
1587
1588impl<'a> Drop for StoreWriteGuard<'a> {
1589    fn drop(&mut self) {
1590        #[cfg(test)]
1591        self.check_invariants();
1592
1593        let metrics = self.metrics();
1594
1595        METRIC_CONNECTIONS.set(metrics.connections as _);
1596        METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1597        METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1598        METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1599
1600        tracing::info!(
1601            connections = metrics.connections,
1602            registered_projects = metrics.registered_projects,
1603            active_projects = metrics.active_projects,
1604            shared_projects = metrics.shared_projects,
1605            "metrics"
1606        );
1607    }
1608}
1609
1610impl Executor for RealExecutor {
1611    type Sleep = Sleep;
1612
1613    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1614        tokio::task::spawn(future);
1615    }
1616
1617    fn sleep(&self, duration: Duration) -> Self::Sleep {
1618        tokio::time::sleep(duration)
1619    }
1620}
1621
1622fn broadcast<F>(
1623    sender_id: ConnectionId,
1624    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1625    mut f: F,
1626) where
1627    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1628{
1629    for receiver_id in receiver_ids {
1630        if receiver_id != sender_id {
1631            f(receiver_id).trace_err();
1632        }
1633    }
1634}
1635
lazy_static! {
    /// Name of the HTTP header used to carry a `ProtocolVersion`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1639
/// Typed header wrapping the numeric protocol version sent in the
/// `x-zed-protocol-version` header.
pub struct ProtocolVersion(u32);
1641
1642impl Header for ProtocolVersion {
1643    fn name() -> &'static HeaderName {
1644        &ZED_PROTOCOL_VERSION
1645    }
1646
1647    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1648    where
1649        Self: Sized,
1650        I: Iterator<Item = &'i axum::http::HeaderValue>,
1651    {
1652        let version = values
1653            .next()
1654            .ok_or_else(|| axum::headers::Error::invalid())?
1655            .to_str()
1656            .map_err(|_| axum::headers::Error::invalid())?
1657            .parse()
1658            .map_err(|_| axum::headers::Error::invalid())?;
1659        Ok(Self(version))
1660    }
1661
1662    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1663        values.extend([self.0.to_string().parse().unwrap()]);
1664    }
1665}
1666
1667pub fn routes(server: Arc<Server>) -> Router<Body> {
1668    Router::new()
1669        .route("/rpc", get(handle_websocket_request))
1670        .layer(
1671            ServiceBuilder::new()
1672                .layer(Extension(server.app_state.clone()))
1673                .layer(middleware::from_fn(auth::validate_header)),
1674        )
1675        .route("/metrics", get(handle_metrics))
1676        .layer(Extension(server))
1677}
1678
1679pub async fn handle_websocket_request(
1680    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1681    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1682    Extension(server): Extension<Arc<Server>>,
1683    Extension(user): Extension<User>,
1684    ws: WebSocketUpgrade,
1685) -> axum::response::Response {
1686    if protocol_version != rpc::PROTOCOL_VERSION {
1687        return (
1688            StatusCode::UPGRADE_REQUIRED,
1689            "client must be upgraded".to_string(),
1690        )
1691            .into_response();
1692    }
1693    let socket_address = socket_address.to_string();
1694    ws.on_upgrade(move |socket| {
1695        use util::ResultExt;
1696        let socket = socket
1697            .map_ok(to_tungstenite_message)
1698            .err_into()
1699            .with(|message| async move { Ok(to_axum_message(message)) });
1700        let connection = Connection::new(Box::pin(socket));
1701        async move {
1702            server
1703                .handle_connection(connection, socket_address, user, None, RealExecutor)
1704                .await
1705                .log_err();
1706        }
1707    })
1708}
1709
1710pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1711    // We call `store_mut` here for its side effects of updating metrics.
1712    server.store_mut().await;
1713
1714    let encoder = prometheus::TextEncoder::new();
1715    let metric_families = prometheus::gather();
1716    match encoder.encode_to_string(&metric_families) {
1717        Ok(string) => (StatusCode::OK, string).into_response(),
1718        Err(error) => (
1719            StatusCode::INTERNAL_SERVER_ERROR,
1720            format!("failed to encode metrics {:?}", error),
1721        )
1722            .into_response(),
1723    }
1724}
1725
1726fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1727    match message {
1728        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1729        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1730        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1731        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1732        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1733            code: frame.code.into(),
1734            reason: frame.reason,
1735        })),
1736    }
1737}
1738
1739fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1740    match message {
1741        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1742        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1743        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1744        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1745        AxumMessage::Close(frame) => {
1746            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1747                code: frame.code.into(),
1748                reason: frame.reason,
1749            }))
1750        }
1751    }
1752}
1753
/// Extension trait for turning a fallible value into an `Option` while
/// reporting the failure instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Returns `Some` with the success value, or reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1759
1760impl<T, E> ResultExt for Result<T, E>
1761where
1762    E: std::fmt::Debug,
1763{
1764    type Ok = T;
1765
1766    fn trace_err(self) -> Option<T> {
1767        match self {
1768            Ok(value) => Some(value),
1769            Err(error) => {
1770                tracing::error!("{:?}", error);
1771                None
1772            }
1773        }
1774    }
1775}