rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use rpc::{
  33    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  34    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  35};
  36use serde::{Serialize, Serializer};
  37use std::{
  38    any::TypeId,
  39    future::Future,
  40    marker::PhantomData,
  41    net::SocketAddr,
  42    ops::{Deref, DerefMut},
  43    rc::Rc,
  44    sync::{
  45        atomic::{AtomicBool, Ordering::SeqCst},
  46        Arc,
  47    },
  48    time::Duration,
  49};
  50use time::OffsetDateTime;
  51use tokio::{
  52    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  53    time::Sleep,
  54};
  55use tower::ServiceBuilder;
  56use tracing::{info_span, instrument, Instrument};
  57
  58pub use store::{Store, Worktree};
  59
  60type MessageHandler =
  61    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  62
/// Handle through which a request handler answers the request it was given.
///
/// `responded` is shared with the dispatcher built by `add_request_handler`, which reads it
/// after the handler finishes to tell whether a reply was sent (or the receipt was taken).
struct Response<R> {
    server: Arc<Server>,
    receipt: Receipt<R>,
    // Set to `true` by `send` and `into_receipt`.
    responded: Arc<AtomicBool>,
}
  68
  69impl<R: RequestMessage> Response<R> {
  70    fn send(self, payload: R::Response) -> Result<()> {
  71        self.responded.store(true, SeqCst);
  72        self.server.peer.respond(self.receipt, payload)?;
  73        Ok(())
  74    }
  75
  76    fn into_receipt(self) -> Receipt<R> {
  77        self.responded.store(true, SeqCst);
  78        self.receipt
  79    }
  80}
  81
/// The RPC server: owns the peer used to talk to clients, the in-memory `Store` of live
/// state, and the table of per-message handlers.
pub struct Server {
    peer: Arc<Peer>,
    /// Live state (connections, projects, ...), guarded by an async `RwLock`.
    pub(crate) store: RwLock<Store>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When present, receives `()` after each incoming message has been handled.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
  89
  90pub trait Executor: Send + Clone {
  91    type Sleep: Send + Future;
  92    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
  93    fn sleep(&self, duration: Duration) -> Self::Sleep;
  94}
  95
/// Stateless `Executor` handle; presumably backed by the real async runtime (tokio) —
/// confirm against its `Executor` impl.
#[derive(Clone)]
pub struct RealExecutor;
  98
  99const MESSAGE_COUNT_PER_PAGE: usize = 100;
 100const MAX_MESSAGE_LEN: usize = 1024;
 101
/// Shared (read) access to the server's `Store`.
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send`, presumably so it cannot be held
/// across an `.await` in a future that must be `Send` — confirm against its users.
struct StoreReadGuard<'a> {
    guard: RwLockReadGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}

/// Exclusive (write) access to the server's `Store`.
///
/// Like `StoreReadGuard`, the `PhantomData<Rc<()>>` marker makes this guard `!Send`.
struct StoreWriteGuard<'a> {
    guard: RwLockWriteGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 111
/// Serializable view of the server's state: the peer plus the store, read through a held
/// read guard (so the store cannot change while the snapshot exists). Presumably used for
/// inspection/debugging output — confirm at call sites.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized as the `Store` it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
 118
 119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 120where
 121    S: Serializer,
 122    T: Deref<Target = U>,
 123    U: Serialize,
 124{
 125    Serialize::serialize(value.deref(), serializer)
 126}
 127
 128impl Server {
 129    pub fn new(
 130        app_state: Arc<AppState>,
 131        notifications: Option<mpsc::UnboundedSender<()>>,
 132    ) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138            notifications,
 139        };
 140
 141        server
 142            .add_request_handler(Server::ping)
 143            .add_request_handler(Server::register_project)
 144            .add_request_handler(Server::unregister_project)
 145            .add_request_handler(Server::join_project)
 146            .add_message_handler(Server::leave_project)
 147            .add_message_handler(Server::respond_to_join_project_request)
 148            .add_message_handler(Server::update_project)
 149            .add_request_handler(Server::update_worktree)
 150            .add_message_handler(Server::start_language_server)
 151            .add_message_handler(Server::update_language_server)
 152            .add_message_handler(Server::update_diagnostic_summary)
 153            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 154            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 155            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 156            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 157            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 158            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 159            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 160            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 161            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 162            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 163            .add_request_handler(
 164                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 165            )
 166            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 167            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 168            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 169            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 170            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 171            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 172            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 173            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 174            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 175            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 176            .add_request_handler(Server::update_buffer)
 177            .add_message_handler(Server::update_buffer_file)
 178            .add_message_handler(Server::buffer_reloaded)
 179            .add_message_handler(Server::buffer_saved)
 180            .add_request_handler(Server::save_buffer)
 181            .add_request_handler(Server::get_channels)
 182            .add_request_handler(Server::get_users)
 183            .add_request_handler(Server::fuzzy_search_users)
 184            .add_request_handler(Server::request_contact)
 185            .add_request_handler(Server::remove_contact)
 186            .add_request_handler(Server::respond_to_contact_request)
 187            .add_request_handler(Server::join_channel)
 188            .add_message_handler(Server::leave_channel)
 189            .add_request_handler(Server::send_channel_message)
 190            .add_request_handler(Server::follow)
 191            .add_message_handler(Server::unfollow)
 192            .add_message_handler(Server::update_followers)
 193            .add_request_handler(Server::get_channel_messages);
 194
 195        Arc::new(server)
 196    }
 197
 198    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 199    where
 200        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 201        Fut: 'static + Send + Future<Output = Result<()>>,
 202        M: EnvelopedMessage,
 203    {
 204        let prev_handler = self.handlers.insert(
 205            TypeId::of::<M>(),
 206            Box::new(move |server, envelope| {
 207                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 208                let span = info_span!(
 209                    "handle message",
 210                    payload_type = envelope.payload_type_name(),
 211                    payload = format!("{:?}", envelope.payload).as_str(),
 212                );
 213                let future = (handler)(server, *envelope);
 214                async move {
 215                    if let Err(error) = future.await {
 216                        tracing::error!(%error, "error handling message");
 217                    }
 218                }
 219                .instrument(span)
 220                .boxed()
 221            }),
 222        );
 223        if prev_handler.is_some() {
 224            panic!("registered a handler for the same message twice");
 225        }
 226        self
 227    }
 228
 229    /// Handle a request while holding a lock to the store. This is useful when we're registering
 230    /// a connection but we want to respond on the connection before anybody else can send on it.
 231    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 232    where
 233        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 234        Fut: Send + Future<Output = Result<()>>,
 235        M: RequestMessage,
 236    {
 237        let handler = Arc::new(handler);
 238        self.add_message_handler(move |server, envelope| {
 239            let receipt = envelope.receipt();
 240            let handler = handler.clone();
 241            async move {
 242                let responded = Arc::new(AtomicBool::default());
 243                let response = Response {
 244                    server: server.clone(),
 245                    responded: responded.clone(),
 246                    receipt: envelope.receipt(),
 247                };
 248                match (handler)(server.clone(), envelope, response).await {
 249                    Ok(()) => {
 250                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 251                            Ok(())
 252                        } else {
 253                            Err(anyhow!("handler did not send a response"))?
 254                        }
 255                    }
 256                    Err(error) => {
 257                        server.peer.respond_with_error(
 258                            receipt,
 259                            proto::Error {
 260                                message: error.to_string(),
 261                            },
 262                        )?;
 263                        Err(error)
 264                    }
 265                }
 266            }
 267        })
 268    }
 269
    /// Returns a future that drives one client connection from registration to sign-out.
    ///
    /// The future:
    /// 1. Registers `connection` with the peer.
    /// 2. Optionally reports the new `ConnectionId` through `send_connection_id`.
    /// 3. Sends the user's initial contacts (and invite info when the user has an invite
    ///    code), then adds the connection to the store.
    /// 4. Dispatches incoming messages to the registered handlers until the connection's I/O
    ///    finishes or the incoming stream ends. Messages whose `is_background()` is true are
    ///    spawned on `executor`; all others are awaited inline, one at a time.
    /// 5. Signs the connection out, logging (not returning) any sign-out error.
    ///
    /// NOTE(review): the `?` operators in the setup phase (before the message loop) return
    /// early without calling `sign_out`. In that case `peer.disconnect` is never called for
    /// this connection, and if `store.add_connection` already ran, its store entry is not
    /// removed either. Confirm whether a caller cleans these up.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The closure passed here lets the peer create timers through `executor`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                })
                .await;

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");

            // Failure to deliver the id is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.as_mut() {
                let _ = send_connection_id.send(connection_id).await;
            }

            // First connection ever for this user: send `ShowContacts` and persist the flag.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Load contacts and the invite code concurrently.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            {
                let mut store = this.store_mut().await;
                store.add_connection(connection_id, user_id);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            // Tell this user's contacts about the user's refreshed contact entry.
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written, so I/O completion
                // is checked before the next incoming message on every iteration.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            async {
                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                    let notifications = this.notifications.clone();
                                    let is_background = message.is_background();
                                    let handle_message = (handler)(this.clone(), message);
                                    // Signal the optional notifications channel once handling completes.
                                    let handle_message = async move {
                                        handle_message.await;
                                        if let Some(mut notifications) = notifications {
                                            let _ = notifications.send(()).await;
                                        }
                                    };
                                    if is_background {
                                        executor.spawn_detached(handle_message);
                                    } else {
                                        handle_message.await;
                                    }
                                } else {
                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                                }
                            }.instrument(span).await;
                        } else {
                            // The incoming stream ended: the client went away.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 378
 379    #[instrument(skip(self), err)]
 380    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 381        self.peer.disconnect(connection_id);
 382
 383        let removed_user_id = {
 384            let mut store = self.store_mut().await;
 385            let removed_connection = store.remove_connection(connection_id)?;
 386
 387            for (project_id, project) in removed_connection.hosted_projects {
 388                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 389                    self.peer
 390                        .send(conn_id, proto::UnregisterProject { project_id })
 391                });
 392
 393                for (_, receipts) in project.join_requests {
 394                    for receipt in receipts {
 395                        self.peer.respond(
 396                            receipt,
 397                            proto::JoinProjectResponse {
 398                                variant: Some(proto::join_project_response::Variant::Decline(
 399                                    proto::join_project_response::Decline {
 400                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 401                                    },
 402                                )),
 403                            },
 404                        )?;
 405                    }
 406                }
 407            }
 408
 409            for project_id in removed_connection.guest_project_ids {
 410                if let Some(project) = store.project(project_id).trace_err() {
 411                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 412                        self.peer.send(
 413                            conn_id,
 414                            proto::RemoveProjectCollaborator {
 415                                project_id,
 416                                peer_id: connection_id.0,
 417                            },
 418                        )
 419                    });
 420                    if project.guests.is_empty() {
 421                        self.peer
 422                            .send(
 423                                project.host_connection_id,
 424                                proto::ProjectUnshared { project_id },
 425                            )
 426                            .trace_err();
 427                    }
 428                }
 429            }
 430
 431            removed_connection.user_id
 432        };
 433
 434        self.update_user_contacts(removed_user_id).await?;
 435
 436        Ok(())
 437    }
 438
 439    pub async fn invite_code_redeemed(
 440        self: &Arc<Self>,
 441        code: &str,
 442        invitee_id: UserId,
 443    ) -> Result<()> {
 444        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 445        let store = self.store().await;
 446        let invitee_contact = store.contact_for_user(invitee_id, true);
 447        for connection_id in store.connection_ids_for_user(user.id) {
 448            self.peer.send(
 449                connection_id,
 450                proto::UpdateContacts {
 451                    contacts: vec![invitee_contact.clone()],
 452                    ..Default::default()
 453                },
 454            )?;
 455            self.peer.send(
 456                connection_id,
 457                proto::UpdateInviteInfo {
 458                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 459                    count: user.invite_count as u32,
 460                },
 461            )?;
 462        }
 463        Ok(())
 464    }
 465
 466    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 467        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 468            if let Some(invite_code) = &user.invite_code {
 469                let store = self.store().await;
 470                for connection_id in store.connection_ids_for_user(user_id) {
 471                    self.peer.send(
 472                        connection_id,
 473                        proto::UpdateInviteInfo {
 474                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 475                            count: user.invite_count as u32,
 476                        },
 477                    )?;
 478                }
 479            }
 480        }
 481        Ok(())
 482    }
 483
 484    async fn ping(
 485        self: Arc<Server>,
 486        _: TypedEnvelope<proto::Ping>,
 487        response: Response<proto::Ping>,
 488    ) -> Result<()> {
 489        response.send(proto::Ack {})?;
 490        Ok(())
 491    }
 492
 493    async fn register_project(
 494        self: Arc<Server>,
 495        request: TypedEnvelope<proto::RegisterProject>,
 496        response: Response<proto::RegisterProject>,
 497    ) -> Result<()> {
 498        let project_id;
 499        {
 500            let mut state = self.store_mut().await;
 501            let user_id = state.user_id_for_connection(request.sender_id)?;
 502            project_id = state.register_project(request.sender_id, user_id);
 503        };
 504
 505        response.send(proto::RegisterProjectResponse { project_id })?;
 506
 507        Ok(())
 508    }
 509
 510    async fn unregister_project(
 511        self: Arc<Server>,
 512        request: TypedEnvelope<proto::UnregisterProject>,
 513        response: Response<proto::UnregisterProject>,
 514    ) -> Result<()> {
 515        let (user_id, project) = {
 516            let mut state = self.store_mut().await;
 517            let project =
 518                state.unregister_project(request.payload.project_id, request.sender_id)?;
 519            (state.user_id_for_connection(request.sender_id)?, project)
 520        };
 521
 522        broadcast(
 523            request.sender_id,
 524            project.guests.keys().copied(),
 525            |conn_id| {
 526                self.peer.send(
 527                    conn_id,
 528                    proto::UnregisterProject {
 529                        project_id: request.payload.project_id,
 530                    },
 531                )
 532            },
 533        );
 534        for (_, receipts) in project.join_requests {
 535            for receipt in receipts {
 536                self.peer.respond(
 537                    receipt,
 538                    proto::JoinProjectResponse {
 539                        variant: Some(proto::join_project_response::Variant::Decline(
 540                            proto::join_project_response::Decline {
 541                                reason: proto::join_project_response::decline::Reason::Closed
 542                                    as i32,
 543                            },
 544                        )),
 545                    },
 546                )?;
 547            }
 548        }
 549
 550        // Send out the `UpdateContacts` message before responding to the unregister
 551        // request. This way, when the project's host can keep track of the project's
 552        // remote id until after they've received the `UpdateContacts` message for
 553        // themself.
 554        self.update_user_contacts(user_id).await?;
 555        response.send(proto::Ack {})?;
 556
 557        Ok(())
 558    }
 559
 560    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 561        let contacts = self.app_state.db.get_contacts(user_id).await?;
 562        let store = self.store().await;
 563        let updated_contact = store.contact_for_user(user_id, false);
 564        for contact in contacts {
 565            if let db::Contact::Accepted {
 566                user_id: contact_user_id,
 567                ..
 568            } = contact
 569            {
 570                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 571                    self.peer
 572                        .send(
 573                            contact_conn_id,
 574                            proto::UpdateContacts {
 575                                contacts: vec![updated_contact.clone()],
 576                                remove_contacts: Default::default(),
 577                                incoming_requests: Default::default(),
 578                                remove_incoming_requests: Default::default(),
 579                                outgoing_requests: Default::default(),
 580                                remove_outgoing_requests: Default::default(),
 581                            },
 582                        )
 583                        .trace_err();
 584                }
 585            }
 586        }
 587        Ok(())
 588    }
 589
 590    async fn join_project(
 591        self: Arc<Server>,
 592        request: TypedEnvelope<proto::JoinProject>,
 593        response: Response<proto::JoinProject>,
 594    ) -> Result<()> {
 595        let project_id = request.payload.project_id;
 596
 597        let host_user_id;
 598        let guest_user_id;
 599        let host_connection_id;
 600        {
 601            let state = self.store().await;
 602            let project = state.project(project_id)?;
 603            host_user_id = project.host_user_id;
 604            host_connection_id = project.host_connection_id;
 605            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 606        };
 607
 608        tracing::info!(project_id, %host_user_id, %host_connection_id, "join project");
 609        let has_contact = self
 610            .app_state
 611            .db
 612            .has_contact(guest_user_id, host_user_id)
 613            .await?;
 614        if !has_contact {
 615            return Err(anyhow!("no such project"))?;
 616        }
 617
 618        self.store_mut().await.request_join_project(
 619            guest_user_id,
 620            project_id,
 621            response.into_receipt(),
 622        )?;
 623        self.peer.send(
 624            host_connection_id,
 625            proto::RequestJoinProject {
 626                project_id,
 627                requester_id: guest_user_id.to_proto(),
 628            },
 629        )?;
 630        Ok(())
 631    }
 632
 633    async fn respond_to_join_project_request(
 634        self: Arc<Server>,
 635        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 636    ) -> Result<()> {
 637        let host_user_id;
 638
 639        {
 640            let mut state = self.store_mut().await;
 641            let project_id = request.payload.project_id;
 642            let project = state.project(project_id)?;
 643            if project.host_connection_id != request.sender_id {
 644                Err(anyhow!("no such connection"))?;
 645            }
 646
 647            host_user_id = project.host_user_id;
 648            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 649
 650            if !request.payload.allow {
 651                let receipts = state
 652                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 653                    .ok_or_else(|| anyhow!("no such request"))?;
 654                for receipt in receipts {
 655                    self.peer.respond(
 656                        receipt,
 657                        proto::JoinProjectResponse {
 658                            variant: Some(proto::join_project_response::Variant::Decline(
 659                                proto::join_project_response::Decline {
 660                                    reason: proto::join_project_response::decline::Reason::Declined
 661                                        as i32,
 662                                },
 663                            )),
 664                        },
 665                    )?;
 666                }
 667                return Ok(());
 668            }
 669
 670            let (receipts_with_replica_ids, project) = state
 671                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 672                .ok_or_else(|| anyhow!("no such request"))?;
 673
 674            let peer_count = project.guests.len();
 675            let mut collaborators = Vec::with_capacity(peer_count);
 676            collaborators.push(proto::Collaborator {
 677                peer_id: project.host_connection_id.0,
 678                replica_id: 0,
 679                user_id: project.host_user_id.to_proto(),
 680            });
 681            let worktrees = project
 682                .worktrees
 683                .iter()
 684                .filter_map(|(id, shared_worktree)| {
 685                    let worktree = project.worktrees.get(&id)?;
 686                    Some(proto::Worktree {
 687                        id: *id,
 688                        root_name: worktree.root_name.clone(),
 689                        entries: shared_worktree.entries.values().cloned().collect(),
 690                        diagnostic_summaries: shared_worktree
 691                            .diagnostic_summaries
 692                            .values()
 693                            .cloned()
 694                            .collect(),
 695                        visible: worktree.visible,
 696                        scan_id: shared_worktree.scan_id,
 697                    })
 698                })
 699                .collect::<Vec<_>>();
 700
 701            // Add all guests other than the requesting user's own connections as collaborators
 702            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 703                if receipts_with_replica_ids
 704                    .iter()
 705                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 706                {
 707                    collaborators.push(proto::Collaborator {
 708                        peer_id: peer_conn_id.0,
 709                        replica_id: *peer_replica_id as u32,
 710                        user_id: peer_user_id.to_proto(),
 711                    });
 712                }
 713            }
 714
 715            for conn_id in project.connection_ids() {
 716                for (receipt, replica_id) in &receipts_with_replica_ids {
 717                    if conn_id != receipt.sender_id {
 718                        self.peer.send(
 719                            conn_id,
 720                            proto::AddProjectCollaborator {
 721                                project_id,
 722                                collaborator: Some(proto::Collaborator {
 723                                    peer_id: receipt.sender_id.0,
 724                                    replica_id: *replica_id as u32,
 725                                    user_id: guest_user_id.to_proto(),
 726                                }),
 727                            },
 728                        )?;
 729                    }
 730                }
 731            }
 732
 733            for (receipt, replica_id) in receipts_with_replica_ids {
 734                self.peer.respond(
 735                    receipt,
 736                    proto::JoinProjectResponse {
 737                        variant: Some(proto::join_project_response::Variant::Accept(
 738                            proto::join_project_response::Accept {
 739                                worktrees: worktrees.clone(),
 740                                replica_id: replica_id as u32,
 741                                collaborators: collaborators.clone(),
 742                                language_servers: project.language_servers.clone(),
 743                            },
 744                        )),
 745                    },
 746                )?;
 747            }
 748        }
 749
 750        self.update_user_contacts(host_user_id).await?;
 751        Ok(())
 752    }
 753
 754    async fn leave_project(
 755        self: Arc<Server>,
 756        request: TypedEnvelope<proto::LeaveProject>,
 757    ) -> Result<()> {
 758        let sender_id = request.sender_id;
 759        let project_id = request.payload.project_id;
 760        let project;
 761        {
 762            let mut store = self.store_mut().await;
 763            project = store.leave_project(sender_id, project_id)?;
 764            tracing::info!(
 765                project_id,
 766                host_user_id = %project.host_user_id,
 767                host_connection_id = %project.host_connection_id,
 768                "leave project"
 769            );
 770
 771            if project.remove_collaborator {
 772                broadcast(sender_id, project.connection_ids, |conn_id| {
 773                    self.peer.send(
 774                        conn_id,
 775                        proto::RemoveProjectCollaborator {
 776                            project_id,
 777                            peer_id: sender_id.0,
 778                        },
 779                    )
 780                });
 781            }
 782
 783            if let Some(requester_id) = project.cancel_request {
 784                self.peer.send(
 785                    project.host_connection_id,
 786                    proto::JoinProjectRequestCancelled {
 787                        project_id,
 788                        requester_id: requester_id.to_proto(),
 789                    },
 790                )?;
 791            }
 792
 793            if project.unshare {
 794                self.peer.send(
 795                    project.host_connection_id,
 796                    proto::ProjectUnshared { project_id },
 797                )?;
 798            }
 799        }
 800        self.update_user_contacts(project.host_user_id).await?;
 801        Ok(())
 802    }
 803
 804    async fn update_project(
 805        self: Arc<Server>,
 806        request: TypedEnvelope<proto::UpdateProject>,
 807    ) -> Result<()> {
 808        let user_id;
 809        {
 810            let mut state = self.store_mut().await;
 811            user_id = state.user_id_for_connection(request.sender_id)?;
 812            let guest_connection_ids = state
 813                .read_project(request.payload.project_id, request.sender_id)?
 814                .guest_connection_ids();
 815            state.update_project(
 816                request.payload.project_id,
 817                &request.payload.worktrees,
 818                request.sender_id,
 819            )?;
 820            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 821                self.peer
 822                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 823            });
 824        };
 825        self.update_user_contacts(user_id).await?;
 826        Ok(())
 827    }
 828
 829    async fn update_worktree(
 830        self: Arc<Server>,
 831        request: TypedEnvelope<proto::UpdateWorktree>,
 832        response: Response<proto::UpdateWorktree>,
 833    ) -> Result<()> {
 834        let (connection_ids, metadata_changed) = {
 835            let mut store = self.store_mut().await;
 836            let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
 837                request.sender_id,
 838                request.payload.project_id,
 839                request.payload.worktree_id,
 840                &request.payload.root_name,
 841                &request.payload.removed_entries,
 842                &request.payload.updated_entries,
 843                request.payload.scan_id,
 844            )?;
 845            for (extension, count) in extension_counts {
 846                tracing::info!(
 847                    project_id = request.payload.project_id,
 848                    worktree_id = request.payload.worktree_id,
 849                    ?extension,
 850                    %count,
 851                    "worktree updated"
 852                );
 853            }
 854            (connection_ids, metadata_changed)
 855        };
 856
 857        broadcast(request.sender_id, connection_ids, |connection_id| {
 858            self.peer
 859                .forward_send(request.sender_id, connection_id, request.payload.clone())
 860        });
 861        if metadata_changed {
 862            let user_id = self
 863                .store()
 864                .await
 865                .user_id_for_connection(request.sender_id)?;
 866            self.update_user_contacts(user_id).await?;
 867        }
 868        response.send(proto::Ack {})?;
 869        Ok(())
 870    }
 871
 872    async fn update_diagnostic_summary(
 873        self: Arc<Server>,
 874        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 875    ) -> Result<()> {
 876        let summary = request
 877            .payload
 878            .summary
 879            .clone()
 880            .ok_or_else(|| anyhow!("invalid summary"))?;
 881        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 882            request.payload.project_id,
 883            request.payload.worktree_id,
 884            request.sender_id,
 885            summary,
 886        )?;
 887
 888        broadcast(request.sender_id, receiver_ids, |connection_id| {
 889            self.peer
 890                .forward_send(request.sender_id, connection_id, request.payload.clone())
 891        });
 892        Ok(())
 893    }
 894
 895    async fn start_language_server(
 896        self: Arc<Server>,
 897        request: TypedEnvelope<proto::StartLanguageServer>,
 898    ) -> Result<()> {
 899        let receiver_ids = self.store_mut().await.start_language_server(
 900            request.payload.project_id,
 901            request.sender_id,
 902            request
 903                .payload
 904                .server
 905                .clone()
 906                .ok_or_else(|| anyhow!("invalid language server"))?,
 907        )?;
 908        broadcast(request.sender_id, receiver_ids, |connection_id| {
 909            self.peer
 910                .forward_send(request.sender_id, connection_id, request.payload.clone())
 911        });
 912        Ok(())
 913    }
 914
 915    async fn update_language_server(
 916        self: Arc<Server>,
 917        request: TypedEnvelope<proto::UpdateLanguageServer>,
 918    ) -> Result<()> {
 919        let receiver_ids = self
 920            .store()
 921            .await
 922            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 923        broadcast(request.sender_id, receiver_ids, |connection_id| {
 924            self.peer
 925                .forward_send(request.sender_id, connection_id, request.payload.clone())
 926        });
 927        Ok(())
 928    }
 929
 930    async fn forward_project_request<T>(
 931        self: Arc<Server>,
 932        request: TypedEnvelope<T>,
 933        response: Response<T>,
 934    ) -> Result<()>
 935    where
 936        T: EntityMessage + RequestMessage,
 937    {
 938        let host_connection_id = self
 939            .store()
 940            .await
 941            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 942            .host_connection_id;
 943
 944        response.send(
 945            self.peer
 946                .forward_request(request.sender_id, host_connection_id, request.payload)
 947                .await?,
 948        )?;
 949        Ok(())
 950    }
 951
 952    async fn save_buffer(
 953        self: Arc<Server>,
 954        request: TypedEnvelope<proto::SaveBuffer>,
 955        response: Response<proto::SaveBuffer>,
 956    ) -> Result<()> {
 957        let host = self
 958            .store()
 959            .await
 960            .read_project(request.payload.project_id, request.sender_id)?
 961            .host_connection_id;
 962        let response_payload = self
 963            .peer
 964            .forward_request(request.sender_id, host, request.payload.clone())
 965            .await?;
 966
 967        let mut guests = self
 968            .store()
 969            .await
 970            .read_project(request.payload.project_id, request.sender_id)?
 971            .connection_ids();
 972        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
 973        broadcast(host, guests, |conn_id| {
 974            self.peer
 975                .forward_send(host, conn_id, response_payload.clone())
 976        });
 977        response.send(response_payload)?;
 978        Ok(())
 979    }
 980
 981    async fn update_buffer(
 982        self: Arc<Server>,
 983        request: TypedEnvelope<proto::UpdateBuffer>,
 984        response: Response<proto::UpdateBuffer>,
 985    ) -> Result<()> {
 986        let receiver_ids = self
 987            .store()
 988            .await
 989            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 990        broadcast(request.sender_id, receiver_ids, |connection_id| {
 991            self.peer
 992                .forward_send(request.sender_id, connection_id, request.payload.clone())
 993        });
 994        response.send(proto::Ack {})?;
 995        Ok(())
 996    }
 997
 998    async fn update_buffer_file(
 999        self: Arc<Server>,
1000        request: TypedEnvelope<proto::UpdateBufferFile>,
1001    ) -> Result<()> {
1002        let receiver_ids = self
1003            .store()
1004            .await
1005            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1006        broadcast(request.sender_id, receiver_ids, |connection_id| {
1007            self.peer
1008                .forward_send(request.sender_id, connection_id, request.payload.clone())
1009        });
1010        Ok(())
1011    }
1012
1013    async fn buffer_reloaded(
1014        self: Arc<Server>,
1015        request: TypedEnvelope<proto::BufferReloaded>,
1016    ) -> Result<()> {
1017        let receiver_ids = self
1018            .store()
1019            .await
1020            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1021        broadcast(request.sender_id, receiver_ids, |connection_id| {
1022            self.peer
1023                .forward_send(request.sender_id, connection_id, request.payload.clone())
1024        });
1025        Ok(())
1026    }
1027
1028    async fn buffer_saved(
1029        self: Arc<Server>,
1030        request: TypedEnvelope<proto::BufferSaved>,
1031    ) -> Result<()> {
1032        let receiver_ids = self
1033            .store()
1034            .await
1035            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1036        broadcast(request.sender_id, receiver_ids, |connection_id| {
1037            self.peer
1038                .forward_send(request.sender_id, connection_id, request.payload.clone())
1039        });
1040        Ok(())
1041    }
1042
1043    async fn follow(
1044        self: Arc<Self>,
1045        request: TypedEnvelope<proto::Follow>,
1046        response: Response<proto::Follow>,
1047    ) -> Result<()> {
1048        let leader_id = ConnectionId(request.payload.leader_id);
1049        let follower_id = request.sender_id;
1050        if !self
1051            .store()
1052            .await
1053            .project_connection_ids(request.payload.project_id, follower_id)?
1054            .contains(&leader_id)
1055        {
1056            Err(anyhow!("no such peer"))?;
1057        }
1058        let mut response_payload = self
1059            .peer
1060            .forward_request(request.sender_id, leader_id, request.payload)
1061            .await?;
1062        response_payload
1063            .views
1064            .retain(|view| view.leader_id != Some(follower_id.0));
1065        response.send(response_payload)?;
1066        Ok(())
1067    }
1068
1069    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1070        let leader_id = ConnectionId(request.payload.leader_id);
1071        if !self
1072            .store()
1073            .await
1074            .project_connection_ids(request.payload.project_id, request.sender_id)?
1075            .contains(&leader_id)
1076        {
1077            Err(anyhow!("no such peer"))?;
1078        }
1079        self.peer
1080            .forward_send(request.sender_id, leader_id, request.payload)?;
1081        Ok(())
1082    }
1083
1084    async fn update_followers(
1085        self: Arc<Self>,
1086        request: TypedEnvelope<proto::UpdateFollowers>,
1087    ) -> Result<()> {
1088        let connection_ids = self
1089            .store()
1090            .await
1091            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1092        let leader_id = request
1093            .payload
1094            .variant
1095            .as_ref()
1096            .and_then(|variant| match variant {
1097                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1098                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1099                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1100            });
1101        for follower_id in &request.payload.follower_ids {
1102            let follower_id = ConnectionId(*follower_id);
1103            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1104                self.peer
1105                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1106            }
1107        }
1108        Ok(())
1109    }
1110
1111    async fn get_channels(
1112        self: Arc<Server>,
1113        request: TypedEnvelope<proto::GetChannels>,
1114        response: Response<proto::GetChannels>,
1115    ) -> Result<()> {
1116        let user_id = self
1117            .store()
1118            .await
1119            .user_id_for_connection(request.sender_id)?;
1120        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1121        response.send(proto::GetChannelsResponse {
1122            channels: channels
1123                .into_iter()
1124                .map(|chan| proto::Channel {
1125                    id: chan.id.to_proto(),
1126                    name: chan.name,
1127                })
1128                .collect(),
1129        })?;
1130        Ok(())
1131    }
1132
1133    async fn get_users(
1134        self: Arc<Server>,
1135        request: TypedEnvelope<proto::GetUsers>,
1136        response: Response<proto::GetUsers>,
1137    ) -> Result<()> {
1138        let user_ids = request
1139            .payload
1140            .user_ids
1141            .into_iter()
1142            .map(UserId::from_proto)
1143            .collect();
1144        let users = self
1145            .app_state
1146            .db
1147            .get_users_by_ids(user_ids)
1148            .await?
1149            .into_iter()
1150            .map(|user| proto::User {
1151                id: user.id.to_proto(),
1152                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1153                github_login: user.github_login,
1154            })
1155            .collect();
1156        response.send(proto::UsersResponse { users })?;
1157        Ok(())
1158    }
1159
1160    async fn fuzzy_search_users(
1161        self: Arc<Server>,
1162        request: TypedEnvelope<proto::FuzzySearchUsers>,
1163        response: Response<proto::FuzzySearchUsers>,
1164    ) -> Result<()> {
1165        let user_id = self
1166            .store()
1167            .await
1168            .user_id_for_connection(request.sender_id)?;
1169        let query = request.payload.query;
1170        let db = &self.app_state.db;
1171        let users = match query.len() {
1172            0 => vec![],
1173            1 | 2 => db
1174                .get_user_by_github_login(&query)
1175                .await?
1176                .into_iter()
1177                .collect(),
1178            _ => db.fuzzy_search_users(&query, 10).await?,
1179        };
1180        let users = users
1181            .into_iter()
1182            .filter(|user| user.id != user_id)
1183            .map(|user| proto::User {
1184                id: user.id.to_proto(),
1185                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1186                github_login: user.github_login,
1187            })
1188            .collect();
1189        response.send(proto::UsersResponse { users })?;
1190        Ok(())
1191    }
1192
1193    async fn request_contact(
1194        self: Arc<Server>,
1195        request: TypedEnvelope<proto::RequestContact>,
1196        response: Response<proto::RequestContact>,
1197    ) -> Result<()> {
1198        let requester_id = self
1199            .store()
1200            .await
1201            .user_id_for_connection(request.sender_id)?;
1202        let responder_id = UserId::from_proto(request.payload.responder_id);
1203        if requester_id == responder_id {
1204            return Err(anyhow!("cannot add yourself as a contact"))?;
1205        }
1206
1207        self.app_state
1208            .db
1209            .send_contact_request(requester_id, responder_id)
1210            .await?;
1211
1212        // Update outgoing contact requests of requester
1213        let mut update = proto::UpdateContacts::default();
1214        update.outgoing_requests.push(responder_id.to_proto());
1215        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1216            self.peer.send(connection_id, update.clone())?;
1217        }
1218
1219        // Update incoming contact requests of responder
1220        let mut update = proto::UpdateContacts::default();
1221        update
1222            .incoming_requests
1223            .push(proto::IncomingContactRequest {
1224                requester_id: requester_id.to_proto(),
1225                should_notify: true,
1226            });
1227        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1228            self.peer.send(connection_id, update.clone())?;
1229        }
1230
1231        response.send(proto::Ack {})?;
1232        Ok(())
1233    }
1234
1235    async fn respond_to_contact_request(
1236        self: Arc<Server>,
1237        request: TypedEnvelope<proto::RespondToContactRequest>,
1238        response: Response<proto::RespondToContactRequest>,
1239    ) -> Result<()> {
1240        let responder_id = self
1241            .store()
1242            .await
1243            .user_id_for_connection(request.sender_id)?;
1244        let requester_id = UserId::from_proto(request.payload.requester_id);
1245        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1246            self.app_state
1247                .db
1248                .dismiss_contact_notification(responder_id, requester_id)
1249                .await?;
1250        } else {
1251            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1252            self.app_state
1253                .db
1254                .respond_to_contact_request(responder_id, requester_id, accept)
1255                .await?;
1256
1257            let store = self.store().await;
1258            // Update responder with new contact
1259            let mut update = proto::UpdateContacts::default();
1260            if accept {
1261                update
1262                    .contacts
1263                    .push(store.contact_for_user(requester_id, false));
1264            }
1265            update
1266                .remove_incoming_requests
1267                .push(requester_id.to_proto());
1268            for connection_id in store.connection_ids_for_user(responder_id) {
1269                self.peer.send(connection_id, update.clone())?;
1270            }
1271
1272            // Update requester with new contact
1273            let mut update = proto::UpdateContacts::default();
1274            if accept {
1275                update
1276                    .contacts
1277                    .push(store.contact_for_user(responder_id, true));
1278            }
1279            update
1280                .remove_outgoing_requests
1281                .push(responder_id.to_proto());
1282            for connection_id in store.connection_ids_for_user(requester_id) {
1283                self.peer.send(connection_id, update.clone())?;
1284            }
1285        }
1286
1287        response.send(proto::Ack {})?;
1288        Ok(())
1289    }
1290
1291    async fn remove_contact(
1292        self: Arc<Server>,
1293        request: TypedEnvelope<proto::RemoveContact>,
1294        response: Response<proto::RemoveContact>,
1295    ) -> Result<()> {
1296        let requester_id = self
1297            .store()
1298            .await
1299            .user_id_for_connection(request.sender_id)?;
1300        let responder_id = UserId::from_proto(request.payload.user_id);
1301        self.app_state
1302            .db
1303            .remove_contact(requester_id, responder_id)
1304            .await?;
1305
1306        // Update outgoing contact requests of requester
1307        let mut update = proto::UpdateContacts::default();
1308        update
1309            .remove_outgoing_requests
1310            .push(responder_id.to_proto());
1311        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1312            self.peer.send(connection_id, update.clone())?;
1313        }
1314
1315        // Update incoming contact requests of responder
1316        let mut update = proto::UpdateContacts::default();
1317        update
1318            .remove_incoming_requests
1319            .push(requester_id.to_proto());
1320        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1321            self.peer.send(connection_id, update.clone())?;
1322        }
1323
1324        response.send(proto::Ack {})?;
1325        Ok(())
1326    }
1327
1328    async fn join_channel(
1329        self: Arc<Self>,
1330        request: TypedEnvelope<proto::JoinChannel>,
1331        response: Response<proto::JoinChannel>,
1332    ) -> Result<()> {
1333        let user_id = self
1334            .store()
1335            .await
1336            .user_id_for_connection(request.sender_id)?;
1337        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1338        if !self
1339            .app_state
1340            .db
1341            .can_user_access_channel(user_id, channel_id)
1342            .await?
1343        {
1344            Err(anyhow!("access denied"))?;
1345        }
1346
1347        self.store_mut()
1348            .await
1349            .join_channel(request.sender_id, channel_id);
1350        let messages = self
1351            .app_state
1352            .db
1353            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1354            .await?
1355            .into_iter()
1356            .map(|msg| proto::ChannelMessage {
1357                id: msg.id.to_proto(),
1358                body: msg.body,
1359                timestamp: msg.sent_at.unix_timestamp() as u64,
1360                sender_id: msg.sender_id.to_proto(),
1361                nonce: Some(msg.nonce.as_u128().into()),
1362            })
1363            .collect::<Vec<_>>();
1364        response.send(proto::JoinChannelResponse {
1365            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1366            messages,
1367        })?;
1368        Ok(())
1369    }
1370
1371    async fn leave_channel(
1372        self: Arc<Self>,
1373        request: TypedEnvelope<proto::LeaveChannel>,
1374    ) -> Result<()> {
1375        let user_id = self
1376            .store()
1377            .await
1378            .user_id_for_connection(request.sender_id)?;
1379        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1380        if !self
1381            .app_state
1382            .db
1383            .can_user_access_channel(user_id, channel_id)
1384            .await?
1385        {
1386            Err(anyhow!("access denied"))?;
1387        }
1388
1389        self.store_mut()
1390            .await
1391            .leave_channel(request.sender_id, channel_id);
1392
1393        Ok(())
1394    }
1395
1396    async fn send_channel_message(
1397        self: Arc<Self>,
1398        request: TypedEnvelope<proto::SendChannelMessage>,
1399        response: Response<proto::SendChannelMessage>,
1400    ) -> Result<()> {
1401        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1402        let user_id;
1403        let connection_ids;
1404        {
1405            let state = self.store().await;
1406            user_id = state.user_id_for_connection(request.sender_id)?;
1407            connection_ids = state.channel_connection_ids(channel_id)?;
1408        }
1409
1410        // Validate the message body.
1411        let body = request.payload.body.trim().to_string();
1412        if body.len() > MAX_MESSAGE_LEN {
1413            return Err(anyhow!("message is too long"))?;
1414        }
1415        if body.is_empty() {
1416            return Err(anyhow!("message can't be blank"))?;
1417        }
1418
1419        let timestamp = OffsetDateTime::now_utc();
1420        let nonce = request
1421            .payload
1422            .nonce
1423            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1424
1425        let message_id = self
1426            .app_state
1427            .db
1428            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1429            .await?
1430            .to_proto();
1431        let message = proto::ChannelMessage {
1432            sender_id: user_id.to_proto(),
1433            id: message_id,
1434            body,
1435            timestamp: timestamp.unix_timestamp() as u64,
1436            nonce: Some(nonce),
1437        };
1438        broadcast(request.sender_id, connection_ids, |conn_id| {
1439            self.peer.send(
1440                conn_id,
1441                proto::ChannelMessageSent {
1442                    channel_id: channel_id.to_proto(),
1443                    message: Some(message.clone()),
1444                },
1445            )
1446        });
1447        response.send(proto::SendChannelMessageResponse {
1448            message: Some(message),
1449        })?;
1450        Ok(())
1451    }
1452
1453    async fn get_channel_messages(
1454        self: Arc<Self>,
1455        request: TypedEnvelope<proto::GetChannelMessages>,
1456        response: Response<proto::GetChannelMessages>,
1457    ) -> Result<()> {
1458        let user_id = self
1459            .store()
1460            .await
1461            .user_id_for_connection(request.sender_id)?;
1462        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1463        if !self
1464            .app_state
1465            .db
1466            .can_user_access_channel(user_id, channel_id)
1467            .await?
1468        {
1469            Err(anyhow!("access denied"))?;
1470        }
1471
1472        let messages = self
1473            .app_state
1474            .db
1475            .get_channel_messages(
1476                channel_id,
1477                MESSAGE_COUNT_PER_PAGE,
1478                Some(MessageId::from_proto(request.payload.before_message_id)),
1479            )
1480            .await?
1481            .into_iter()
1482            .map(|msg| proto::ChannelMessage {
1483                id: msg.id.to_proto(),
1484                body: msg.body,
1485                timestamp: msg.sent_at.unix_timestamp() as u64,
1486                sender_id: msg.sender_id.to_proto(),
1487                nonce: Some(msg.nonce.as_u128().into()),
1488            })
1489            .collect::<Vec<_>>();
1490        response.send(proto::GetChannelMessagesResponse {
1491            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1492            messages,
1493        })?;
1494        Ok(())
1495    }
1496
1497    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1498        #[cfg(test)]
1499        tokio::task::yield_now().await;
1500        let guard = self.store.read().await;
1501        #[cfg(test)]
1502        tokio::task::yield_now().await;
1503        StoreReadGuard {
1504            guard,
1505            _not_send: PhantomData,
1506        }
1507    }
1508
1509    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1510        #[cfg(test)]
1511        tokio::task::yield_now().await;
1512        let guard = self.store.write().await;
1513        #[cfg(test)]
1514        tokio::task::yield_now().await;
1515        StoreWriteGuard {
1516            guard,
1517            _not_send: PhantomData,
1518        }
1519    }
1520
1521    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1522        ServerSnapshot {
1523            store: self.store.read().await,
1524            peer: &self.peer,
1525        }
1526    }
1527}
1528
1529impl<'a> Deref for StoreReadGuard<'a> {
1530    type Target = Store;
1531
1532    fn deref(&self) -> &Self::Target {
1533        &*self.guard
1534    }
1535}
1536
1537impl<'a> Deref for StoreWriteGuard<'a> {
1538    type Target = Store;
1539
1540    fn deref(&self) -> &Self::Target {
1541        &*self.guard
1542    }
1543}
1544
1545impl<'a> DerefMut for StoreWriteGuard<'a> {
1546    fn deref_mut(&mut self) -> &mut Self::Target {
1547        &mut *self.guard
1548    }
1549}
1550
1551impl<'a> Drop for StoreWriteGuard<'a> {
1552    fn drop(&mut self) {
1553        #[cfg(test)]
1554        self.check_invariants();
1555
1556        let metrics = self.metrics();
1557        tracing::info!(
1558            connections = metrics.connections,
1559            registered_projects = metrics.registered_projects,
1560            shared_projects = metrics.shared_projects,
1561            "metrics"
1562        );
1563    }
1564}
1565
1566impl Executor for RealExecutor {
1567    type Sleep = Sleep;
1568
1569    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1570        tokio::task::spawn(future);
1571    }
1572
1573    fn sleep(&self, duration: Duration) -> Self::Sleep {
1574        tokio::time::sleep(duration)
1575    }
1576}
1577
1578fn broadcast<F>(
1579    sender_id: ConnectionId,
1580    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1581    mut f: F,
1582) where
1583    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1584{
1585    for receiver_id in receiver_ids {
1586        if receiver_id != sender_id {
1587            f(receiver_id).trace_err();
1588        }
1589    }
1590}
1591
// Name of the request header (`x-zed-protocol-version`) that carries the client's RPC protocol
// version; `ProtocolVersion`'s `Header` impl returns it from `name()`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1595
/// Typed value of the `x-zed-protocol-version` request header: the RPC protocol version the
/// client reports, compared against `rpc::PROTOCOL_VERSION` when a websocket upgrade is handled.
pub struct ProtocolVersion(u32);
1597
1598impl Header for ProtocolVersion {
1599    fn name() -> &'static HeaderName {
1600        &ZED_PROTOCOL_VERSION
1601    }
1602
1603    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1604    where
1605        Self: Sized,
1606        I: Iterator<Item = &'i axum::http::HeaderValue>,
1607    {
1608        let version = values
1609            .next()
1610            .ok_or_else(|| axum::headers::Error::invalid())?
1611            .to_str()
1612            .map_err(|_| axum::headers::Error::invalid())?
1613            .parse()
1614            .map_err(|_| axum::headers::Error::invalid())?;
1615        Ok(Self(version))
1616    }
1617
1618    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1619        values.extend([self.0.to_string().parse().unwrap()]);
1620    }
1621}
1622
/// Builds the router that serves the RPC websocket endpoint at `GET /rpc`.
///
/// `ServiceBuilder` applies layers in the order they are added, so the first one
/// added is outermost and sees the request first. Concretely:
/// - The `AppState` extension is inserted before `auth::validate_header` runs.
///   Presumably the auth middleware reads it; confirm in `auth`.
/// - `handle_websocket_request` extracts an `Extension<User>`, presumably inserted
///   by `auth::validate_header`. TODO confirm.
/// - The `Arc<Server>` extension is the innermost layer.
///
/// Do not reorder these layers without checking those dependencies.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header))
                .layer(Extension(server)),
        )
}
1633
/// Upgrades a `GET /rpc` request to a websocket and passes the resulting
/// connection to `Server::handle_connection`.
///
/// Clients are rejected with `426 Upgrade Required`, before any upgrade, when the
/// `x-zed-protocol-version` header (see `ProtocolVersion`) differs from
/// `rpc::PROTOCOL_VERSION`. The `User` comes from a request extension. It is
/// presumably inserted by the auth middleware installed in `routes`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Brings `log_err` into scope for the connection future below.
        use util::ResultExt;
        // Adapt axum's websocket stream/sink so the `Connection` sees tungstenite
        // messages. Incoming items are converted with `to_tungstenite_message`
        // and outgoing ones with `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // There is no caller to return an error to after the upgrade, so
            // connection errors are only logged.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1664
1665fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1666    match message {
1667        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1668        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1669        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1670        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1671        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1672            code: frame.code.into(),
1673            reason: frame.reason,
1674        })),
1675    }
1676}
1677
1678fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1679    match message {
1680        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1681        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1682        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1683        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1684        AxumMessage::Close(frame) => {
1685            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1686                code: frame.code.into(),
1687                reason: frame.reason,
1688            }))
1689        }
1690    }
1691}
1692
1693pub trait ResultExt {
1694    type Ok;
1695
1696    fn trace_err(self) -> Option<Self::Ok>;
1697}
1698
1699impl<T, E> ResultExt for Result<T, E>
1700where
1701    E: std::fmt::Debug,
1702{
1703    type Ok = T;
1704
1705    fn trace_err(self) -> Option<T> {
1706        match self {
1707            Ok(value) => Some(value),
1708            Err(error) => {
1709                tracing::error!("{:?}", error);
1710                None
1711            }
1712        }
1713    }
1714}