rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use rpc::{
  33    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  34    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  35};
  36use serde::{Serialize, Serializer};
  37use std::{
  38    any::TypeId,
  39    future::Future,
  40    marker::PhantomData,
  41    net::SocketAddr,
  42    ops::{Deref, DerefMut},
  43    rc::Rc,
  44    sync::{
  45        atomic::{AtomicBool, Ordering::SeqCst},
  46        Arc,
  47    },
  48    time::Duration,
  49};
  50use time::OffsetDateTime;
  51use tokio::{
  52    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  53    time::Sleep,
  54};
  55use tower::ServiceBuilder;
  56use tracing::{info_span, instrument, Instrument};
  57
  58pub use store::{Store, Worktree};
  59
  60type MessageHandler =
  61    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  62
  63struct Response<R> {
  64    server: Arc<Server>,
  65    receipt: Receipt<R>,
  66    responded: Arc<AtomicBool>,
  67}
  68
  69impl<R: RequestMessage> Response<R> {
  70    fn send(self, payload: R::Response) -> Result<()> {
  71        self.responded.store(true, SeqCst);
  72        self.server.peer.respond(self.receipt, payload)?;
  73        Ok(())
  74    }
  75
  76    fn into_receipt(self) -> Receipt<R> {
  77        self.responded.store(true, SeqCst);
  78        self.receipt
  79    }
  80}
  81
  82pub struct Server {
  83    peer: Arc<Peer>,
  84    pub(crate) store: RwLock<Store>,
  85    app_state: Arc<AppState>,
  86    handlers: HashMap<TypeId, MessageHandler>,
  87    notifications: Option<mpsc::UnboundedSender<()>>,
  88}
  89
  90pub trait Executor: Send + Clone {
  91    type Sleep: Send + Future;
  92    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
  93    fn sleep(&self, duration: Duration) -> Self::Sleep;
  94}
  95
  96#[derive(Clone)]
  97pub struct RealExecutor;
  98
  99const MESSAGE_COUNT_PER_PAGE: usize = 100;
 100const MAX_MESSAGE_LEN: usize = 1024;
 101
 102struct StoreReadGuard<'a> {
 103    guard: RwLockReadGuard<'a, Store>,
 104    _not_send: PhantomData<Rc<()>>,
 105}
 106
 107struct StoreWriteGuard<'a> {
 108    guard: RwLockWriteGuard<'a, Store>,
 109    _not_send: PhantomData<Rc<()>>,
 110}
 111
 112#[derive(Serialize)]
 113pub struct ServerSnapshot<'a> {
 114    peer: &'a Peer,
 115    #[serde(serialize_with = "serialize_deref")]
 116    store: RwLockReadGuard<'a, Store>,
 117}
 118
 119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 120where
 121    S: Serializer,
 122    T: Deref<Target = U>,
 123    U: Serialize,
 124{
 125    Serialize::serialize(value.deref(), serializer)
 126}
 127
 128impl Server {
 129    pub fn new(
 130        app_state: Arc<AppState>,
 131        notifications: Option<mpsc::UnboundedSender<()>>,
 132    ) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138            notifications,
 139        };
 140
 141        server
 142            .add_request_handler(Server::ping)
 143            .add_request_handler(Server::register_project)
 144            .add_request_handler(Server::unregister_project)
 145            .add_request_handler(Server::join_project)
 146            .add_message_handler(Server::leave_project)
 147            .add_message_handler(Server::respond_to_join_project_request)
 148            .add_message_handler(Server::update_project)
 149            .add_request_handler(Server::update_worktree)
 150            .add_message_handler(Server::start_language_server)
 151            .add_message_handler(Server::update_language_server)
 152            .add_message_handler(Server::update_diagnostic_summary)
 153            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 154            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 155            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 156            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 157            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 158            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 159            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 160            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 161            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 162            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 163            .add_request_handler(
 164                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 165            )
 166            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 167            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 168            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 169            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 170            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 171            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 172            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 173            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 174            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 175            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 176            .add_request_handler(Server::update_buffer)
 177            .add_message_handler(Server::update_buffer_file)
 178            .add_message_handler(Server::buffer_reloaded)
 179            .add_message_handler(Server::buffer_saved)
 180            .add_request_handler(Server::save_buffer)
 181            .add_request_handler(Server::get_channels)
 182            .add_request_handler(Server::get_users)
 183            .add_request_handler(Server::fuzzy_search_users)
 184            .add_request_handler(Server::request_contact)
 185            .add_request_handler(Server::remove_contact)
 186            .add_request_handler(Server::respond_to_contact_request)
 187            .add_request_handler(Server::join_channel)
 188            .add_message_handler(Server::leave_channel)
 189            .add_request_handler(Server::send_channel_message)
 190            .add_request_handler(Server::follow)
 191            .add_message_handler(Server::unfollow)
 192            .add_message_handler(Server::update_followers)
 193            .add_request_handler(Server::get_channel_messages);
 194
 195        Arc::new(server)
 196    }
 197
 198    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 199    where
 200        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 201        Fut: 'static + Send + Future<Output = Result<()>>,
 202        M: EnvelopedMessage,
 203    {
 204        let prev_handler = self.handlers.insert(
 205            TypeId::of::<M>(),
 206            Box::new(move |server, envelope| {
 207                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 208                let span = info_span!(
 209                    "handle message",
 210                    payload_type = envelope.payload_type_name(),
 211                    payload = format!("{:?}", envelope.payload).as_str(),
 212                );
 213                let future = (handler)(server, *envelope);
 214                async move {
 215                    if let Err(error) = future.await {
 216                        tracing::error!(%error, "error handling message");
 217                    }
 218                }
 219                .instrument(span)
 220                .boxed()
 221            }),
 222        );
 223        if prev_handler.is_some() {
 224            panic!("registered a handler for the same message twice");
 225        }
 226        self
 227    }
 228
 229    /// Handle a request while holding a lock to the store. This is useful when we're registering
 230    /// a connection but we want to respond on the connection before anybody else can send on it.
 231    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 232    where
 233        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 234        Fut: Send + Future<Output = Result<()>>,
 235        M: RequestMessage,
 236    {
 237        let handler = Arc::new(handler);
 238        self.add_message_handler(move |server, envelope| {
 239            let receipt = envelope.receipt();
 240            let handler = handler.clone();
 241            async move {
 242                let responded = Arc::new(AtomicBool::default());
 243                let response = Response {
 244                    server: server.clone(),
 245                    responded: responded.clone(),
 246                    receipt: envelope.receipt(),
 247                };
 248                match (handler)(server.clone(), envelope, response).await {
 249                    Ok(()) => {
 250                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 251                            Ok(())
 252                        } else {
 253                            Err(anyhow!("handler did not send a response"))?
 254                        }
 255                    }
 256                    Err(error) => {
 257                        server.peer.respond_with_error(
 258                            receipt,
 259                            proto::Error {
 260                                message: error.to_string(),
 261                            },
 262                        )?;
 263                        Err(error)
 264                    }
 265                }
 266            }
 267        })
 268    }
 269
 270    pub fn handle_connection<E: Executor>(
 271        self: &Arc<Self>,
 272        connection: Connection,
 273        address: String,
 274        user: User,
 275        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 276        executor: E,
 277    ) -> impl Future<Output = Result<()>> {
 278        let mut this = self.clone();
 279        let user_id = user.id;
 280        let login = user.github_login;
 281        let span = info_span!("handle connection", %user_id, %login, %address);
 282        async move {
 283            let (connection_id, handle_io, mut incoming_rx) = this
 284                .peer
 285                .add_connection(connection, {
 286                    let executor = executor.clone();
 287                    move |duration| {
 288                        let timer = executor.sleep(duration);
 289                        async move {
 290                            timer.await;
 291                        }
 292                    }
 293                })
 294                .await;
 295
 296            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 297
 298            if let Some(send_connection_id) = send_connection_id.as_mut() {
 299                let _ = send_connection_id.send(connection_id).await;
 300            }
 301
 302            if !user.connected_once {
 303                this.peer.send(connection_id, proto::ShowContacts {})?;
 304                this.app_state.db.set_user_connected_once(user_id, true).await?;
 305            }
 306
 307            let (contacts, invite_code) = future::try_join(
 308                this.app_state.db.get_contacts(user_id),
 309                this.app_state.db.get_invite_code_for_user(user_id)
 310            ).await?;
 311
 312            {
 313                let mut store = this.store_mut().await;
 314                store.add_connection(connection_id, user_id);
 315                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 316
 317                if let Some((code, count)) = invite_code {
 318                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 319                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 320                        count,
 321                    })?;
 322                }
 323            }
 324            this.update_user_contacts(user_id).await?;
 325
 326            let handle_io = handle_io.fuse();
 327            futures::pin_mut!(handle_io);
 328            loop {
 329                let next_message = incoming_rx.next().fuse();
 330                futures::pin_mut!(next_message);
 331                futures::select_biased! {
 332                    result = handle_io => {
 333                        if let Err(error) = result {
 334                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 335                        }
 336                        break;
 337                    }
 338                    message = next_message => {
 339                        if let Some(message) = message {
 340                            let type_name = message.payload_type_name();
 341                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 342                            async {
 343                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 344                                    let notifications = this.notifications.clone();
 345                                    let is_background = message.is_background();
 346                                    let handle_message = (handler)(this.clone(), message);
 347                                    let handle_message = async move {
 348                                        handle_message.await;
 349                                        if let Some(mut notifications) = notifications {
 350                                            let _ = notifications.send(()).await;
 351                                        }
 352                                    };
 353                                    if is_background {
 354                                        executor.spawn_detached(handle_message);
 355                                    } else {
 356                                        handle_message.await;
 357                                    }
 358                                } else {
 359                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 360                                }
 361                            }.instrument(span).await;
 362                        } else {
 363                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 364                            break;
 365                        }
 366                    }
 367                }
 368            }
 369
 370            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 371            if let Err(error) = this.sign_out(connection_id).await {
 372                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 373            }
 374
 375            Ok(())
 376        }.instrument(span)
 377    }
 378
 379    #[instrument(skip(self), err)]
 380    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 381        self.peer.disconnect(connection_id);
 382
 383        let removed_user_id = {
 384            let mut store = self.store_mut().await;
 385            let removed_connection = store.remove_connection(connection_id)?;
 386
 387            for (project_id, project) in removed_connection.hosted_projects {
 388                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 389                    self.peer
 390                        .send(conn_id, proto::UnregisterProject { project_id })
 391                });
 392
 393                for (_, receipts) in project.join_requests {
 394                    for receipt in receipts {
 395                        self.peer.respond(
 396                            receipt,
 397                            proto::JoinProjectResponse {
 398                                variant: Some(proto::join_project_response::Variant::Decline(
 399                                    proto::join_project_response::Decline {
 400                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 401                                    },
 402                                )),
 403                            },
 404                        )?;
 405                    }
 406                }
 407            }
 408
 409            for project_id in removed_connection.guest_project_ids {
 410                if let Some(project) = store.project(project_id).trace_err() {
 411                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 412                        self.peer.send(
 413                            conn_id,
 414                            proto::RemoveProjectCollaborator {
 415                                project_id,
 416                                peer_id: connection_id.0,
 417                            },
 418                        )
 419                    });
 420                    if project.guests.is_empty() {
 421                        self.peer
 422                            .send(
 423                                project.host_connection_id,
 424                                proto::ProjectUnshared { project_id },
 425                            )
 426                            .trace_err();
 427                    }
 428                }
 429            }
 430
 431            removed_connection.user_id
 432        };
 433
 434        self.update_user_contacts(removed_user_id).await?;
 435
 436        Ok(())
 437    }
 438
 439    pub async fn invite_code_redeemed(
 440        self: &Arc<Self>,
 441        code: &str,
 442        invitee_id: UserId,
 443    ) -> Result<()> {
 444        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 445        let store = self.store().await;
 446        let invitee_contact = store.contact_for_user(invitee_id, true);
 447        for connection_id in store.connection_ids_for_user(user.id) {
 448            self.peer.send(
 449                connection_id,
 450                proto::UpdateContacts {
 451                    contacts: vec![invitee_contact.clone()],
 452                    ..Default::default()
 453                },
 454            )?;
 455            self.peer.send(
 456                connection_id,
 457                proto::UpdateInviteInfo {
 458                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 459                    count: user.invite_count as u32,
 460                },
 461            )?;
 462        }
 463        Ok(())
 464    }
 465
 466    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 467        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 468            if let Some(invite_code) = &user.invite_code {
 469                let store = self.store().await;
 470                for connection_id in store.connection_ids_for_user(user_id) {
 471                    self.peer.send(
 472                        connection_id,
 473                        proto::UpdateInviteInfo {
 474                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 475                            count: user.invite_count as u32,
 476                        },
 477                    )?;
 478                }
 479            }
 480        }
 481        Ok(())
 482    }
 483
 484    async fn ping(
 485        self: Arc<Server>,
 486        _: TypedEnvelope<proto::Ping>,
 487        response: Response<proto::Ping>,
 488    ) -> Result<()> {
 489        response.send(proto::Ack {})?;
 490        Ok(())
 491    }
 492
 493    async fn register_project(
 494        self: Arc<Server>,
 495        request: TypedEnvelope<proto::RegisterProject>,
 496        response: Response<proto::RegisterProject>,
 497    ) -> Result<()> {
 498        let project_id;
 499        {
 500            let mut state = self.store_mut().await;
 501            let user_id = state.user_id_for_connection(request.sender_id)?;
 502            project_id = state.register_project(request.sender_id, user_id);
 503        };
 504
 505        response.send(proto::RegisterProjectResponse { project_id })?;
 506
 507        Ok(())
 508    }
 509
 510    async fn unregister_project(
 511        self: Arc<Server>,
 512        request: TypedEnvelope<proto::UnregisterProject>,
 513        response: Response<proto::UnregisterProject>,
 514    ) -> Result<()> {
 515        let (user_id, project) = {
 516            let mut state = self.store_mut().await;
 517            let project =
 518                state.unregister_project(request.payload.project_id, request.sender_id)?;
 519            (state.user_id_for_connection(request.sender_id)?, project)
 520        };
 521
 522        broadcast(
 523            request.sender_id,
 524            project.guests.keys().copied(),
 525            |conn_id| {
 526                self.peer.send(
 527                    conn_id,
 528                    proto::UnregisterProject {
 529                        project_id: request.payload.project_id,
 530                    },
 531                )
 532            },
 533        );
 534        for (_, receipts) in project.join_requests {
 535            for receipt in receipts {
 536                self.peer.respond(
 537                    receipt,
 538                    proto::JoinProjectResponse {
 539                        variant: Some(proto::join_project_response::Variant::Decline(
 540                            proto::join_project_response::Decline {
 541                                reason: proto::join_project_response::decline::Reason::Closed
 542                                    as i32,
 543                            },
 544                        )),
 545                    },
 546                )?;
 547            }
 548        }
 549
 550        // Send out the `UpdateContacts` message before responding to the unregister
 551        // request. This way, when the project's host can keep track of the project's
 552        // remote id until after they've received the `UpdateContacts` message for
 553        // themself.
 554        self.update_user_contacts(user_id).await?;
 555        response.send(proto::Ack {})?;
 556
 557        Ok(())
 558    }
 559
 560    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 561        let contacts = self.app_state.db.get_contacts(user_id).await?;
 562        let store = self.store().await;
 563        let updated_contact = store.contact_for_user(user_id, false);
 564        for contact in contacts {
 565            if let db::Contact::Accepted {
 566                user_id: contact_user_id,
 567                ..
 568            } = contact
 569            {
 570                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 571                    self.peer
 572                        .send(
 573                            contact_conn_id,
 574                            proto::UpdateContacts {
 575                                contacts: vec![updated_contact.clone()],
 576                                remove_contacts: Default::default(),
 577                                incoming_requests: Default::default(),
 578                                remove_incoming_requests: Default::default(),
 579                                outgoing_requests: Default::default(),
 580                                remove_outgoing_requests: Default::default(),
 581                            },
 582                        )
 583                        .trace_err();
 584                }
 585            }
 586        }
 587        Ok(())
 588    }
 589
 590    async fn join_project(
 591        self: Arc<Server>,
 592        request: TypedEnvelope<proto::JoinProject>,
 593        response: Response<proto::JoinProject>,
 594    ) -> Result<()> {
 595        let project_id = request.payload.project_id;
 596
 597        let host_user_id;
 598        let guest_user_id;
 599        let host_connection_id;
 600        {
 601            let state = self.store().await;
 602            let project = state.project(project_id)?;
 603            host_user_id = project.host_user_id;
 604            host_connection_id = project.host_connection_id;
 605            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 606        };
 607
 608        let has_contact = self
 609            .app_state
 610            .db
 611            .has_contact(guest_user_id, host_user_id)
 612            .await?;
 613        if !has_contact {
 614            return Err(anyhow!("no such project"))?;
 615        }
 616
 617        self.store_mut().await.request_join_project(
 618            guest_user_id,
 619            project_id,
 620            response.into_receipt(),
 621        )?;
 622        self.peer.send(
 623            host_connection_id,
 624            proto::RequestJoinProject {
 625                project_id,
 626                requester_id: guest_user_id.to_proto(),
 627            },
 628        )?;
 629        Ok(())
 630    }
 631
 632    async fn respond_to_join_project_request(
 633        self: Arc<Server>,
 634        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 635    ) -> Result<()> {
 636        let host_user_id;
 637
 638        {
 639            let mut state = self.store_mut().await;
 640            let project_id = request.payload.project_id;
 641            let project = state.project(project_id)?;
 642            if project.host_connection_id != request.sender_id {
 643                Err(anyhow!("no such connection"))?;
 644            }
 645
 646            host_user_id = project.host_user_id;
 647            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 648
 649            if !request.payload.allow {
 650                let receipts = state
 651                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 652                    .ok_or_else(|| anyhow!("no such request"))?;
 653                for receipt in receipts {
 654                    self.peer.respond(
 655                        receipt,
 656                        proto::JoinProjectResponse {
 657                            variant: Some(proto::join_project_response::Variant::Decline(
 658                                proto::join_project_response::Decline {
 659                                    reason: proto::join_project_response::decline::Reason::Declined
 660                                        as i32,
 661                                },
 662                            )),
 663                        },
 664                    )?;
 665                }
 666                return Ok(());
 667            }
 668
 669            let (receipts_with_replica_ids, project) = state
 670                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 671                .ok_or_else(|| anyhow!("no such request"))?;
 672
 673            let peer_count = project.guests.len();
 674            let mut collaborators = Vec::with_capacity(peer_count);
 675            collaborators.push(proto::Collaborator {
 676                peer_id: project.host_connection_id.0,
 677                replica_id: 0,
 678                user_id: project.host_user_id.to_proto(),
 679            });
 680            let worktrees = project
 681                .worktrees
 682                .iter()
 683                .filter_map(|(id, shared_worktree)| {
 684                    let worktree = project.worktrees.get(&id)?;
 685                    Some(proto::Worktree {
 686                        id: *id,
 687                        root_name: worktree.root_name.clone(),
 688                        entries: shared_worktree.entries.values().cloned().collect(),
 689                        diagnostic_summaries: shared_worktree
 690                            .diagnostic_summaries
 691                            .values()
 692                            .cloned()
 693                            .collect(),
 694                        visible: worktree.visible,
 695                        scan_id: shared_worktree.scan_id,
 696                    })
 697                })
 698                .collect::<Vec<_>>();
 699
 700            // Add all guests other than the requesting user's own connections as collaborators
 701            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 702                if receipts_with_replica_ids
 703                    .iter()
 704                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 705                {
 706                    collaborators.push(proto::Collaborator {
 707                        peer_id: peer_conn_id.0,
 708                        replica_id: *peer_replica_id as u32,
 709                        user_id: peer_user_id.to_proto(),
 710                    });
 711                }
 712            }
 713
 714            for conn_id in project.connection_ids() {
 715                for (receipt, replica_id) in &receipts_with_replica_ids {
 716                    if conn_id != receipt.sender_id {
 717                        self.peer.send(
 718                            conn_id,
 719                            proto::AddProjectCollaborator {
 720                                project_id,
 721                                collaborator: Some(proto::Collaborator {
 722                                    peer_id: receipt.sender_id.0,
 723                                    replica_id: *replica_id as u32,
 724                                    user_id: guest_user_id.to_proto(),
 725                                }),
 726                            },
 727                        )?;
 728                    }
 729                }
 730            }
 731
 732            for (receipt, replica_id) in receipts_with_replica_ids {
 733                self.peer.respond(
 734                    receipt,
 735                    proto::JoinProjectResponse {
 736                        variant: Some(proto::join_project_response::Variant::Accept(
 737                            proto::join_project_response::Accept {
 738                                worktrees: worktrees.clone(),
 739                                replica_id: replica_id as u32,
 740                                collaborators: collaborators.clone(),
 741                                language_servers: project.language_servers.clone(),
 742                            },
 743                        )),
 744                    },
 745                )?;
 746            }
 747        }
 748
 749        self.update_user_contacts(host_user_id).await?;
 750        Ok(())
 751    }
 752
 753    async fn leave_project(
 754        self: Arc<Server>,
 755        request: TypedEnvelope<proto::LeaveProject>,
 756    ) -> Result<()> {
 757        let sender_id = request.sender_id;
 758        let project_id = request.payload.project_id;
 759        let project;
 760        {
 761            let mut store = self.store_mut().await;
 762            project = store.leave_project(sender_id, project_id)?;
 763
 764            if project.remove_collaborator {
 765                broadcast(sender_id, project.connection_ids, |conn_id| {
 766                    self.peer.send(
 767                        conn_id,
 768                        proto::RemoveProjectCollaborator {
 769                            project_id,
 770                            peer_id: sender_id.0,
 771                        },
 772                    )
 773                });
 774            }
 775
 776            if let Some(requester_id) = project.cancel_request {
 777                self.peer.send(
 778                    project.host_connection_id,
 779                    proto::JoinProjectRequestCancelled {
 780                        project_id,
 781                        requester_id: requester_id.to_proto(),
 782                    },
 783                )?;
 784            }
 785
 786            if project.unshare {
 787                self.peer.send(
 788                    project.host_connection_id,
 789                    proto::ProjectUnshared { project_id },
 790                )?;
 791            }
 792        }
 793        self.update_user_contacts(project.host_user_id).await?;
 794        Ok(())
 795    }
 796
 797    async fn update_project(
 798        self: Arc<Server>,
 799        request: TypedEnvelope<proto::UpdateProject>,
 800    ) -> Result<()> {
 801        let user_id;
 802        {
 803            let mut state = self.store_mut().await;
 804            user_id = state.user_id_for_connection(request.sender_id)?;
 805            let guest_connection_ids = state
 806                .read_project(request.payload.project_id, request.sender_id)?
 807                .guest_connection_ids();
 808            state.update_project(
 809                request.payload.project_id,
 810                &request.payload.worktrees,
 811                request.sender_id,
 812            )?;
 813            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 814                self.peer
 815                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 816            });
 817        };
 818        self.update_user_contacts(user_id).await?;
 819        Ok(())
 820    }
 821
 822    async fn update_worktree(
 823        self: Arc<Server>,
 824        request: TypedEnvelope<proto::UpdateWorktree>,
 825        response: Response<proto::UpdateWorktree>,
 826    ) -> Result<()> {
 827        let (connection_ids, metadata_changed) = self.store_mut().await.update_worktree(
 828            request.sender_id,
 829            request.payload.project_id,
 830            request.payload.worktree_id,
 831            &request.payload.root_name,
 832            &request.payload.removed_entries,
 833            &request.payload.updated_entries,
 834            request.payload.scan_id,
 835        )?;
 836
 837        broadcast(request.sender_id, connection_ids, |connection_id| {
 838            self.peer
 839                .forward_send(request.sender_id, connection_id, request.payload.clone())
 840        });
 841        if metadata_changed {
 842            let user_id = self
 843                .store()
 844                .await
 845                .user_id_for_connection(request.sender_id)?;
 846            self.update_user_contacts(user_id).await?;
 847        }
 848        response.send(proto::Ack {})?;
 849        Ok(())
 850    }
 851
 852    async fn update_diagnostic_summary(
 853        self: Arc<Server>,
 854        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 855    ) -> Result<()> {
 856        let summary = request
 857            .payload
 858            .summary
 859            .clone()
 860            .ok_or_else(|| anyhow!("invalid summary"))?;
 861        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 862            request.payload.project_id,
 863            request.payload.worktree_id,
 864            request.sender_id,
 865            summary,
 866        )?;
 867
 868        broadcast(request.sender_id, receiver_ids, |connection_id| {
 869            self.peer
 870                .forward_send(request.sender_id, connection_id, request.payload.clone())
 871        });
 872        Ok(())
 873    }
 874
 875    async fn start_language_server(
 876        self: Arc<Server>,
 877        request: TypedEnvelope<proto::StartLanguageServer>,
 878    ) -> Result<()> {
 879        let receiver_ids = self.store_mut().await.start_language_server(
 880            request.payload.project_id,
 881            request.sender_id,
 882            request
 883                .payload
 884                .server
 885                .clone()
 886                .ok_or_else(|| anyhow!("invalid language server"))?,
 887        )?;
 888        broadcast(request.sender_id, receiver_ids, |connection_id| {
 889            self.peer
 890                .forward_send(request.sender_id, connection_id, request.payload.clone())
 891        });
 892        Ok(())
 893    }
 894
 895    async fn update_language_server(
 896        self: Arc<Server>,
 897        request: TypedEnvelope<proto::UpdateLanguageServer>,
 898    ) -> Result<()> {
 899        let receiver_ids = self
 900            .store()
 901            .await
 902            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 903        broadcast(request.sender_id, receiver_ids, |connection_id| {
 904            self.peer
 905                .forward_send(request.sender_id, connection_id, request.payload.clone())
 906        });
 907        Ok(())
 908    }
 909
 910    async fn forward_project_request<T>(
 911        self: Arc<Server>,
 912        request: TypedEnvelope<T>,
 913        response: Response<T>,
 914    ) -> Result<()>
 915    where
 916        T: EntityMessage + RequestMessage,
 917    {
 918        let host_connection_id = self
 919            .store()
 920            .await
 921            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 922            .host_connection_id;
 923
 924        response.send(
 925            self.peer
 926                .forward_request(request.sender_id, host_connection_id, request.payload)
 927                .await?,
 928        )?;
 929        Ok(())
 930    }
 931
 932    async fn save_buffer(
 933        self: Arc<Server>,
 934        request: TypedEnvelope<proto::SaveBuffer>,
 935        response: Response<proto::SaveBuffer>,
 936    ) -> Result<()> {
 937        let host = self
 938            .store()
 939            .await
 940            .read_project(request.payload.project_id, request.sender_id)?
 941            .host_connection_id;
 942        let response_payload = self
 943            .peer
 944            .forward_request(request.sender_id, host, request.payload.clone())
 945            .await?;
 946
 947        let mut guests = self
 948            .store()
 949            .await
 950            .read_project(request.payload.project_id, request.sender_id)?
 951            .connection_ids();
 952        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
 953        broadcast(host, guests, |conn_id| {
 954            self.peer
 955                .forward_send(host, conn_id, response_payload.clone())
 956        });
 957        response.send(response_payload)?;
 958        Ok(())
 959    }
 960
 961    async fn update_buffer(
 962        self: Arc<Server>,
 963        request: TypedEnvelope<proto::UpdateBuffer>,
 964        response: Response<proto::UpdateBuffer>,
 965    ) -> Result<()> {
 966        let receiver_ids = self
 967            .store()
 968            .await
 969            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 970        broadcast(request.sender_id, receiver_ids, |connection_id| {
 971            self.peer
 972                .forward_send(request.sender_id, connection_id, request.payload.clone())
 973        });
 974        response.send(proto::Ack {})?;
 975        Ok(())
 976    }
 977
 978    async fn update_buffer_file(
 979        self: Arc<Server>,
 980        request: TypedEnvelope<proto::UpdateBufferFile>,
 981    ) -> Result<()> {
 982        let receiver_ids = self
 983            .store()
 984            .await
 985            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 986        broadcast(request.sender_id, receiver_ids, |connection_id| {
 987            self.peer
 988                .forward_send(request.sender_id, connection_id, request.payload.clone())
 989        });
 990        Ok(())
 991    }
 992
 993    async fn buffer_reloaded(
 994        self: Arc<Server>,
 995        request: TypedEnvelope<proto::BufferReloaded>,
 996    ) -> Result<()> {
 997        let receiver_ids = self
 998            .store()
 999            .await
1000            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1001        broadcast(request.sender_id, receiver_ids, |connection_id| {
1002            self.peer
1003                .forward_send(request.sender_id, connection_id, request.payload.clone())
1004        });
1005        Ok(())
1006    }
1007
1008    async fn buffer_saved(
1009        self: Arc<Server>,
1010        request: TypedEnvelope<proto::BufferSaved>,
1011    ) -> Result<()> {
1012        let receiver_ids = self
1013            .store()
1014            .await
1015            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1016        broadcast(request.sender_id, receiver_ids, |connection_id| {
1017            self.peer
1018                .forward_send(request.sender_id, connection_id, request.payload.clone())
1019        });
1020        Ok(())
1021    }
1022
1023    async fn follow(
1024        self: Arc<Self>,
1025        request: TypedEnvelope<proto::Follow>,
1026        response: Response<proto::Follow>,
1027    ) -> Result<()> {
1028        let leader_id = ConnectionId(request.payload.leader_id);
1029        let follower_id = request.sender_id;
1030        if !self
1031            .store()
1032            .await
1033            .project_connection_ids(request.payload.project_id, follower_id)?
1034            .contains(&leader_id)
1035        {
1036            Err(anyhow!("no such peer"))?;
1037        }
1038        let mut response_payload = self
1039            .peer
1040            .forward_request(request.sender_id, leader_id, request.payload)
1041            .await?;
1042        response_payload
1043            .views
1044            .retain(|view| view.leader_id != Some(follower_id.0));
1045        response.send(response_payload)?;
1046        Ok(())
1047    }
1048
1049    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1050        let leader_id = ConnectionId(request.payload.leader_id);
1051        if !self
1052            .store()
1053            .await
1054            .project_connection_ids(request.payload.project_id, request.sender_id)?
1055            .contains(&leader_id)
1056        {
1057            Err(anyhow!("no such peer"))?;
1058        }
1059        self.peer
1060            .forward_send(request.sender_id, leader_id, request.payload)?;
1061        Ok(())
1062    }
1063
1064    async fn update_followers(
1065        self: Arc<Self>,
1066        request: TypedEnvelope<proto::UpdateFollowers>,
1067    ) -> Result<()> {
1068        let connection_ids = self
1069            .store()
1070            .await
1071            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1072        let leader_id = request
1073            .payload
1074            .variant
1075            .as_ref()
1076            .and_then(|variant| match variant {
1077                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1078                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1079                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1080            });
1081        for follower_id in &request.payload.follower_ids {
1082            let follower_id = ConnectionId(*follower_id);
1083            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1084                self.peer
1085                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1086            }
1087        }
1088        Ok(())
1089    }
1090
1091    async fn get_channels(
1092        self: Arc<Server>,
1093        request: TypedEnvelope<proto::GetChannels>,
1094        response: Response<proto::GetChannels>,
1095    ) -> Result<()> {
1096        let user_id = self
1097            .store()
1098            .await
1099            .user_id_for_connection(request.sender_id)?;
1100        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1101        response.send(proto::GetChannelsResponse {
1102            channels: channels
1103                .into_iter()
1104                .map(|chan| proto::Channel {
1105                    id: chan.id.to_proto(),
1106                    name: chan.name,
1107                })
1108                .collect(),
1109        })?;
1110        Ok(())
1111    }
1112
1113    async fn get_users(
1114        self: Arc<Server>,
1115        request: TypedEnvelope<proto::GetUsers>,
1116        response: Response<proto::GetUsers>,
1117    ) -> Result<()> {
1118        let user_ids = request
1119            .payload
1120            .user_ids
1121            .into_iter()
1122            .map(UserId::from_proto)
1123            .collect();
1124        let users = self
1125            .app_state
1126            .db
1127            .get_users_by_ids(user_ids)
1128            .await?
1129            .into_iter()
1130            .map(|user| proto::User {
1131                id: user.id.to_proto(),
1132                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1133                github_login: user.github_login,
1134            })
1135            .collect();
1136        response.send(proto::UsersResponse { users })?;
1137        Ok(())
1138    }
1139
1140    async fn fuzzy_search_users(
1141        self: Arc<Server>,
1142        request: TypedEnvelope<proto::FuzzySearchUsers>,
1143        response: Response<proto::FuzzySearchUsers>,
1144    ) -> Result<()> {
1145        let user_id = self
1146            .store()
1147            .await
1148            .user_id_for_connection(request.sender_id)?;
1149        let query = request.payload.query;
1150        let db = &self.app_state.db;
1151        let users = match query.len() {
1152            0 => vec![],
1153            1 | 2 => db
1154                .get_user_by_github_login(&query)
1155                .await?
1156                .into_iter()
1157                .collect(),
1158            _ => db.fuzzy_search_users(&query, 10).await?,
1159        };
1160        let users = users
1161            .into_iter()
1162            .filter(|user| user.id != user_id)
1163            .map(|user| proto::User {
1164                id: user.id.to_proto(),
1165                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1166                github_login: user.github_login,
1167            })
1168            .collect();
1169        response.send(proto::UsersResponse { users })?;
1170        Ok(())
1171    }
1172
1173    async fn request_contact(
1174        self: Arc<Server>,
1175        request: TypedEnvelope<proto::RequestContact>,
1176        response: Response<proto::RequestContact>,
1177    ) -> Result<()> {
1178        let requester_id = self
1179            .store()
1180            .await
1181            .user_id_for_connection(request.sender_id)?;
1182        let responder_id = UserId::from_proto(request.payload.responder_id);
1183        if requester_id == responder_id {
1184            return Err(anyhow!("cannot add yourself as a contact"))?;
1185        }
1186
1187        self.app_state
1188            .db
1189            .send_contact_request(requester_id, responder_id)
1190            .await?;
1191
1192        // Update outgoing contact requests of requester
1193        let mut update = proto::UpdateContacts::default();
1194        update.outgoing_requests.push(responder_id.to_proto());
1195        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1196            self.peer.send(connection_id, update.clone())?;
1197        }
1198
1199        // Update incoming contact requests of responder
1200        let mut update = proto::UpdateContacts::default();
1201        update
1202            .incoming_requests
1203            .push(proto::IncomingContactRequest {
1204                requester_id: requester_id.to_proto(),
1205                should_notify: true,
1206            });
1207        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1208            self.peer.send(connection_id, update.clone())?;
1209        }
1210
1211        response.send(proto::Ack {})?;
1212        Ok(())
1213    }
1214
1215    async fn respond_to_contact_request(
1216        self: Arc<Server>,
1217        request: TypedEnvelope<proto::RespondToContactRequest>,
1218        response: Response<proto::RespondToContactRequest>,
1219    ) -> Result<()> {
1220        let responder_id = self
1221            .store()
1222            .await
1223            .user_id_for_connection(request.sender_id)?;
1224        let requester_id = UserId::from_proto(request.payload.requester_id);
1225        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1226            self.app_state
1227                .db
1228                .dismiss_contact_notification(responder_id, requester_id)
1229                .await?;
1230        } else {
1231            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1232            self.app_state
1233                .db
1234                .respond_to_contact_request(responder_id, requester_id, accept)
1235                .await?;
1236
1237            let store = self.store().await;
1238            // Update responder with new contact
1239            let mut update = proto::UpdateContacts::default();
1240            if accept {
1241                update
1242                    .contacts
1243                    .push(store.contact_for_user(requester_id, false));
1244            }
1245            update
1246                .remove_incoming_requests
1247                .push(requester_id.to_proto());
1248            for connection_id in store.connection_ids_for_user(responder_id) {
1249                self.peer.send(connection_id, update.clone())?;
1250            }
1251
1252            // Update requester with new contact
1253            let mut update = proto::UpdateContacts::default();
1254            if accept {
1255                update
1256                    .contacts
1257                    .push(store.contact_for_user(responder_id, true));
1258            }
1259            update
1260                .remove_outgoing_requests
1261                .push(responder_id.to_proto());
1262            for connection_id in store.connection_ids_for_user(requester_id) {
1263                self.peer.send(connection_id, update.clone())?;
1264            }
1265        }
1266
1267        response.send(proto::Ack {})?;
1268        Ok(())
1269    }
1270
1271    async fn remove_contact(
1272        self: Arc<Server>,
1273        request: TypedEnvelope<proto::RemoveContact>,
1274        response: Response<proto::RemoveContact>,
1275    ) -> Result<()> {
1276        let requester_id = self
1277            .store()
1278            .await
1279            .user_id_for_connection(request.sender_id)?;
1280        let responder_id = UserId::from_proto(request.payload.user_id);
1281        self.app_state
1282            .db
1283            .remove_contact(requester_id, responder_id)
1284            .await?;
1285
1286        // Update outgoing contact requests of requester
1287        let mut update = proto::UpdateContacts::default();
1288        update
1289            .remove_outgoing_requests
1290            .push(responder_id.to_proto());
1291        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1292            self.peer.send(connection_id, update.clone())?;
1293        }
1294
1295        // Update incoming contact requests of responder
1296        let mut update = proto::UpdateContacts::default();
1297        update
1298            .remove_incoming_requests
1299            .push(requester_id.to_proto());
1300        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1301            self.peer.send(connection_id, update.clone())?;
1302        }
1303
1304        response.send(proto::Ack {})?;
1305        Ok(())
1306    }
1307
1308    async fn join_channel(
1309        self: Arc<Self>,
1310        request: TypedEnvelope<proto::JoinChannel>,
1311        response: Response<proto::JoinChannel>,
1312    ) -> Result<()> {
1313        let user_id = self
1314            .store()
1315            .await
1316            .user_id_for_connection(request.sender_id)?;
1317        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1318        if !self
1319            .app_state
1320            .db
1321            .can_user_access_channel(user_id, channel_id)
1322            .await?
1323        {
1324            Err(anyhow!("access denied"))?;
1325        }
1326
1327        self.store_mut()
1328            .await
1329            .join_channel(request.sender_id, channel_id);
1330        let messages = self
1331            .app_state
1332            .db
1333            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1334            .await?
1335            .into_iter()
1336            .map(|msg| proto::ChannelMessage {
1337                id: msg.id.to_proto(),
1338                body: msg.body,
1339                timestamp: msg.sent_at.unix_timestamp() as u64,
1340                sender_id: msg.sender_id.to_proto(),
1341                nonce: Some(msg.nonce.as_u128().into()),
1342            })
1343            .collect::<Vec<_>>();
1344        response.send(proto::JoinChannelResponse {
1345            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1346            messages,
1347        })?;
1348        Ok(())
1349    }
1350
1351    async fn leave_channel(
1352        self: Arc<Self>,
1353        request: TypedEnvelope<proto::LeaveChannel>,
1354    ) -> Result<()> {
1355        let user_id = self
1356            .store()
1357            .await
1358            .user_id_for_connection(request.sender_id)?;
1359        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1360        if !self
1361            .app_state
1362            .db
1363            .can_user_access_channel(user_id, channel_id)
1364            .await?
1365        {
1366            Err(anyhow!("access denied"))?;
1367        }
1368
1369        self.store_mut()
1370            .await
1371            .leave_channel(request.sender_id, channel_id);
1372
1373        Ok(())
1374    }
1375
1376    async fn send_channel_message(
1377        self: Arc<Self>,
1378        request: TypedEnvelope<proto::SendChannelMessage>,
1379        response: Response<proto::SendChannelMessage>,
1380    ) -> Result<()> {
1381        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1382        let user_id;
1383        let connection_ids;
1384        {
1385            let state = self.store().await;
1386            user_id = state.user_id_for_connection(request.sender_id)?;
1387            connection_ids = state.channel_connection_ids(channel_id)?;
1388        }
1389
1390        // Validate the message body.
1391        let body = request.payload.body.trim().to_string();
1392        if body.len() > MAX_MESSAGE_LEN {
1393            return Err(anyhow!("message is too long"))?;
1394        }
1395        if body.is_empty() {
1396            return Err(anyhow!("message can't be blank"))?;
1397        }
1398
1399        let timestamp = OffsetDateTime::now_utc();
1400        let nonce = request
1401            .payload
1402            .nonce
1403            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1404
1405        let message_id = self
1406            .app_state
1407            .db
1408            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1409            .await?
1410            .to_proto();
1411        let message = proto::ChannelMessage {
1412            sender_id: user_id.to_proto(),
1413            id: message_id,
1414            body,
1415            timestamp: timestamp.unix_timestamp() as u64,
1416            nonce: Some(nonce),
1417        };
1418        broadcast(request.sender_id, connection_ids, |conn_id| {
1419            self.peer.send(
1420                conn_id,
1421                proto::ChannelMessageSent {
1422                    channel_id: channel_id.to_proto(),
1423                    message: Some(message.clone()),
1424                },
1425            )
1426        });
1427        response.send(proto::SendChannelMessageResponse {
1428            message: Some(message),
1429        })?;
1430        Ok(())
1431    }
1432
1433    async fn get_channel_messages(
1434        self: Arc<Self>,
1435        request: TypedEnvelope<proto::GetChannelMessages>,
1436        response: Response<proto::GetChannelMessages>,
1437    ) -> Result<()> {
1438        let user_id = self
1439            .store()
1440            .await
1441            .user_id_for_connection(request.sender_id)?;
1442        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1443        if !self
1444            .app_state
1445            .db
1446            .can_user_access_channel(user_id, channel_id)
1447            .await?
1448        {
1449            Err(anyhow!("access denied"))?;
1450        }
1451
1452        let messages = self
1453            .app_state
1454            .db
1455            .get_channel_messages(
1456                channel_id,
1457                MESSAGE_COUNT_PER_PAGE,
1458                Some(MessageId::from_proto(request.payload.before_message_id)),
1459            )
1460            .await?
1461            .into_iter()
1462            .map(|msg| proto::ChannelMessage {
1463                id: msg.id.to_proto(),
1464                body: msg.body,
1465                timestamp: msg.sent_at.unix_timestamp() as u64,
1466                sender_id: msg.sender_id.to_proto(),
1467                nonce: Some(msg.nonce.as_u128().into()),
1468            })
1469            .collect::<Vec<_>>();
1470        response.send(proto::GetChannelMessagesResponse {
1471            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1472            messages,
1473        })?;
1474        Ok(())
1475    }
1476
1477    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1478        #[cfg(test)]
1479        tokio::task::yield_now().await;
1480        let guard = self.store.read().await;
1481        #[cfg(test)]
1482        tokio::task::yield_now().await;
1483        StoreReadGuard {
1484            guard,
1485            _not_send: PhantomData,
1486        }
1487    }
1488
1489    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1490        #[cfg(test)]
1491        tokio::task::yield_now().await;
1492        let guard = self.store.write().await;
1493        #[cfg(test)]
1494        tokio::task::yield_now().await;
1495        StoreWriteGuard {
1496            guard,
1497            _not_send: PhantomData,
1498        }
1499    }
1500
1501    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1502        ServerSnapshot {
1503            store: self.store.read().await,
1504            peer: &self.peer,
1505        }
1506    }
1507}
1508
1509impl<'a> Deref for StoreReadGuard<'a> {
1510    type Target = Store;
1511
1512    fn deref(&self) -> &Self::Target {
1513        &*self.guard
1514    }
1515}
1516
1517impl<'a> Deref for StoreWriteGuard<'a> {
1518    type Target = Store;
1519
1520    fn deref(&self) -> &Self::Target {
1521        &*self.guard
1522    }
1523}
1524
1525impl<'a> DerefMut for StoreWriteGuard<'a> {
1526    fn deref_mut(&mut self) -> &mut Self::Target {
1527        &mut *self.guard
1528    }
1529}
1530
1531impl<'a> Drop for StoreWriteGuard<'a> {
1532    fn drop(&mut self) {
1533        #[cfg(test)]
1534        self.check_invariants();
1535
1536        let metrics = self.metrics();
1537        tracing::info!(
1538            connections = metrics.connections,
1539            registered_projects = metrics.registered_projects,
1540            shared_projects = metrics.shared_projects,
1541            "metrics"
1542        );
1543    }
1544}
1545
1546impl Executor for RealExecutor {
1547    type Sleep = Sleep;
1548
1549    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1550        tokio::task::spawn(future);
1551    }
1552
1553    fn sleep(&self, duration: Duration) -> Self::Sleep {
1554        tokio::time::sleep(duration)
1555    }
1556}
1557
1558fn broadcast<F>(
1559    sender_id: ConnectionId,
1560    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1561    mut f: F,
1562) where
1563    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1564{
1565    for receiver_id in receiver_ids {
1566        if receiver_id != sender_id {
1567            f(receiver_id).trace_err();
1568        }
1569    }
1570}
1571
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version;
    /// it is decoded into `ProtocolVersion`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1575
/// Typed form of the `x-zed-protocol-version` header: the client's RPC protocol version.
pub struct ProtocolVersion(u32);
1577
1578impl Header for ProtocolVersion {
1579    fn name() -> &'static HeaderName {
1580        &ZED_PROTOCOL_VERSION
1581    }
1582
1583    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1584    where
1585        Self: Sized,
1586        I: Iterator<Item = &'i axum::http::HeaderValue>,
1587    {
1588        let version = values
1589            .next()
1590            .ok_or_else(|| axum::headers::Error::invalid())?
1591            .to_str()
1592            .map_err(|_| axum::headers::Error::invalid())?
1593            .parse()
1594            .map_err(|_| axum::headers::Error::invalid())?;
1595        Ok(Self(version))
1596    }
1597
1598    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1599        values.extend([self.0.to_string().parse().unwrap()]);
1600    }
1601}
1602
1603pub fn routes(server: Arc<Server>) -> Router<Body> {
1604    Router::new()
1605        .route("/rpc", get(handle_websocket_request))
1606        .layer(
1607            ServiceBuilder::new()
1608                .layer(Extension(server.app_state.clone()))
1609                .layer(middleware::from_fn(auth::validate_header))
1610                .layer(Extension(server)),
1611        )
1612}
1613
1614pub async fn handle_websocket_request(
1615    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1616    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1617    Extension(server): Extension<Arc<Server>>,
1618    Extension(user): Extension<User>,
1619    ws: WebSocketUpgrade,
1620) -> axum::response::Response {
1621    if protocol_version != rpc::PROTOCOL_VERSION {
1622        return (
1623            StatusCode::UPGRADE_REQUIRED,
1624            "client must be upgraded".to_string(),
1625        )
1626            .into_response();
1627    }
1628    let socket_address = socket_address.to_string();
1629    ws.on_upgrade(move |socket| {
1630        use util::ResultExt;
1631        let socket = socket
1632            .map_ok(to_tungstenite_message)
1633            .err_into()
1634            .with(|message| async move { Ok(to_axum_message(message)) });
1635        let connection = Connection::new(Box::pin(socket));
1636        async move {
1637            server
1638                .handle_connection(connection, socket_address, user, None, RealExecutor)
1639                .await
1640                .log_err();
1641        }
1642    })
1643}
1644
1645fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1646    match message {
1647        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1648        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1649        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1650        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1651        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1652            code: frame.code.into(),
1653            reason: frame.reason,
1654        })),
1655    }
1656}
1657
1658fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1659    match message {
1660        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1661        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1662        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1663        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1664        AxumMessage::Close(frame) => {
1665            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1666                code: frame.code.into(),
1667                reason: frame.reason,
1668            }))
1669        }
1670    }
1671}
1672
/// Extension trait for dropping an error after reporting it.
///
/// This module implements it for `Result<T, E>` with `E: Debug`. That impl
/// logs an `Err` with `tracing::error!` using `{:?}` formatting.
pub trait ResultExt {
    /// The success value carried by the implementing type.
    type Ok;

    /// Returns `Some(value)` on success. On failure, reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1678
1679impl<T, E> ResultExt for Result<T, E>
1680where
1681    E: std::fmt::Debug,
1682{
1683    type Ok = T;
1684
1685    fn trace_err(self) -> Option<T> {
1686        match self {
1687            Ok(value) => Some(value),
1688            Err(error) => {
1689                tracing::error!("{:?}", error);
1690                None
1691            }
1692        }
1693    }
1694}