rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use rpc::{
  33    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  34    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  35};
  36use serde::{Serialize, Serializer};
  37use std::{
  38    any::TypeId,
  39    future::Future,
  40    marker::PhantomData,
  41    net::SocketAddr,
  42    ops::{Deref, DerefMut},
  43    rc::Rc,
  44    sync::{
  45        atomic::{AtomicBool, Ordering::SeqCst},
  46        Arc,
  47    },
  48    time::Duration,
  49};
  50use time::OffsetDateTime;
  51use tokio::{
  52    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  53    time::Sleep,
  54};
  55use tower::ServiceBuilder;
  56use tracing::{info_span, instrument, Instrument};
  57
  58pub use store::{Store, Worktree};
  59
  60type MessageHandler =
  61    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  62
  63struct Response<R> {
  64    server: Arc<Server>,
  65    receipt: Receipt<R>,
  66    responded: Arc<AtomicBool>,
  67}
  68
  69impl<R: RequestMessage> Response<R> {
  70    fn send(self, payload: R::Response) -> Result<()> {
  71        self.responded.store(true, SeqCst);
  72        self.server.peer.respond(self.receipt, payload)?;
  73        Ok(())
  74    }
  75
  76    fn into_receipt(self) -> Receipt<R> {
  77        self.responded.store(true, SeqCst);
  78        self.receipt
  79    }
  80}
  81
  82pub struct Server {
  83    peer: Arc<Peer>,
  84    pub(crate) store: RwLock<Store>,
  85    app_state: Arc<AppState>,
  86    handlers: HashMap<TypeId, MessageHandler>,
  87    notifications: Option<mpsc::UnboundedSender<()>>,
  88}
  89
  90pub trait Executor: Send + Clone {
  91    type Sleep: Send + Future;
  92    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
  93    fn sleep(&self, duration: Duration) -> Self::Sleep;
  94}
  95
  96#[derive(Clone)]
  97pub struct RealExecutor;
  98
  99const MESSAGE_COUNT_PER_PAGE: usize = 100;
 100const MAX_MESSAGE_LEN: usize = 1024;
 101
 102struct StoreReadGuard<'a> {
 103    guard: RwLockReadGuard<'a, Store>,
 104    _not_send: PhantomData<Rc<()>>,
 105}
 106
 107struct StoreWriteGuard<'a> {
 108    guard: RwLockWriteGuard<'a, Store>,
 109    _not_send: PhantomData<Rc<()>>,
 110}
 111
 112#[derive(Serialize)]
 113pub struct ServerSnapshot<'a> {
 114    peer: &'a Peer,
 115    #[serde(serialize_with = "serialize_deref")]
 116    store: RwLockReadGuard<'a, Store>,
 117}
 118
 119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 120where
 121    S: Serializer,
 122    T: Deref<Target = U>,
 123    U: Serialize,
 124{
 125    Serialize::serialize(value.deref(), serializer)
 126}
 127
 128impl Server {
 129    pub fn new(
 130        app_state: Arc<AppState>,
 131        notifications: Option<mpsc::UnboundedSender<()>>,
 132    ) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138            notifications,
 139        };
 140
 141        server
 142            .add_request_handler(Server::ping)
 143            .add_request_handler(Server::register_project)
 144            .add_message_handler(Server::unregister_project)
 145            .add_request_handler(Server::join_project)
 146            .add_message_handler(Server::leave_project)
 147            .add_message_handler(Server::respond_to_join_project_request)
 148            .add_request_handler(Server::register_worktree)
 149            .add_message_handler(Server::unregister_worktree)
 150            .add_request_handler(Server::update_worktree)
 151            .add_message_handler(Server::start_language_server)
 152            .add_message_handler(Server::update_language_server)
 153            .add_message_handler(Server::update_diagnostic_summary)
 154            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 155            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 156            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 157            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 158            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 159            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 160            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 161            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 162            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 163            .add_request_handler(
 164                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 165            )
 166            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 167            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 168            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 169            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 170            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 171            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 172            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 173            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 174            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 175            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 176            .add_request_handler(Server::update_buffer)
 177            .add_message_handler(Server::update_buffer_file)
 178            .add_message_handler(Server::buffer_reloaded)
 179            .add_message_handler(Server::buffer_saved)
 180            .add_request_handler(Server::save_buffer)
 181            .add_request_handler(Server::get_channels)
 182            .add_request_handler(Server::get_users)
 183            .add_request_handler(Server::fuzzy_search_users)
 184            .add_request_handler(Server::request_contact)
 185            .add_request_handler(Server::remove_contact)
 186            .add_request_handler(Server::respond_to_contact_request)
 187            .add_request_handler(Server::join_channel)
 188            .add_message_handler(Server::leave_channel)
 189            .add_request_handler(Server::send_channel_message)
 190            .add_request_handler(Server::follow)
 191            .add_message_handler(Server::unfollow)
 192            .add_message_handler(Server::update_followers)
 193            .add_request_handler(Server::get_channel_messages);
 194
 195        Arc::new(server)
 196    }
 197
 198    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 199    where
 200        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 201        Fut: 'static + Send + Future<Output = Result<()>>,
 202        M: EnvelopedMessage,
 203    {
 204        let prev_handler = self.handlers.insert(
 205            TypeId::of::<M>(),
 206            Box::new(move |server, envelope| {
 207                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 208                let span = info_span!(
 209                    "handle message",
 210                    payload_type = envelope.payload_type_name(),
 211                    payload = format!("{:?}", envelope.payload).as_str(),
 212                );
 213                let future = (handler)(server, *envelope);
 214                async move {
 215                    if let Err(error) = future.await {
 216                        tracing::error!(%error, "error handling message");
 217                    }
 218                }
 219                .instrument(span)
 220                .boxed()
 221            }),
 222        );
 223        if prev_handler.is_some() {
 224            panic!("registered a handler for the same message twice");
 225        }
 226        self
 227    }
 228
 229    /// Handle a request while holding a lock to the store. This is useful when we're registering
 230    /// a connection but we want to respond on the connection before anybody else can send on it.
 231    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 232    where
 233        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 234        Fut: Send + Future<Output = Result<()>>,
 235        M: RequestMessage,
 236    {
 237        let handler = Arc::new(handler);
 238        self.add_message_handler(move |server, envelope| {
 239            let receipt = envelope.receipt();
 240            let handler = handler.clone();
 241            async move {
 242                let responded = Arc::new(AtomicBool::default());
 243                let response = Response {
 244                    server: server.clone(),
 245                    responded: responded.clone(),
 246                    receipt: envelope.receipt(),
 247                };
 248                match (handler)(server.clone(), envelope, response).await {
 249                    Ok(()) => {
 250                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 251                            Ok(())
 252                        } else {
 253                            Err(anyhow!("handler did not send a response"))?
 254                        }
 255                    }
 256                    Err(error) => {
 257                        server.peer.respond_with_error(
 258                            receipt,
 259                            proto::Error {
 260                                message: error.to_string(),
 261                            },
 262                        )?;
 263                        Err(error)
 264                    }
 265                }
 266            }
 267        })
 268    }
 269
 270    pub fn handle_connection<E: Executor>(
 271        self: &Arc<Self>,
 272        connection: Connection,
 273        address: String,
 274        user: User,
 275        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 276        executor: E,
 277    ) -> impl Future<Output = Result<()>> {
 278        let mut this = self.clone();
 279        let user_id = user.id;
 280        let login = user.github_login;
 281        let span = info_span!("handle connection", %user_id, %login, %address);
 282        async move {
 283            let (connection_id, handle_io, mut incoming_rx) = this
 284                .peer
 285                .add_connection(connection, {
 286                    let executor = executor.clone();
 287                    move |duration| {
 288                        let timer = executor.sleep(duration);
 289                        async move {
 290                            timer.await;
 291                        }
 292                    }
 293                })
 294                .await;
 295
 296            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 297
 298            if let Some(send_connection_id) = send_connection_id.as_mut() {
 299                let _ = send_connection_id.send(connection_id).await;
 300            }
 301
 302            if !user.connected_once {
 303                this.peer.send(connection_id, proto::ShowContacts {})?;
 304                this.app_state.db.set_user_connected_once(user_id, true).await?;
 305            }
 306
 307            let (contacts, invite_code) = future::try_join(
 308                this.app_state.db.get_contacts(user_id),
 309                this.app_state.db.get_invite_code_for_user(user_id)
 310            ).await?;
 311
 312            {
 313                let mut store = this.store_mut().await;
 314                store.add_connection(connection_id, user_id);
 315                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 316
 317                if let Some((code, count)) = invite_code {
 318                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 319                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 320                        count,
 321                    })?;
 322                }
 323            }
 324            this.update_user_contacts(user_id).await?;
 325
 326            let handle_io = handle_io.fuse();
 327            futures::pin_mut!(handle_io);
 328            loop {
 329                let next_message = incoming_rx.next().fuse();
 330                futures::pin_mut!(next_message);
 331                futures::select_biased! {
 332                    result = handle_io => {
 333                        if let Err(error) = result {
 334                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 335                        }
 336                        break;
 337                    }
 338                    message = next_message => {
 339                        if let Some(message) = message {
 340                            let type_name = message.payload_type_name();
 341                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 342                            async {
 343                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 344                                    let notifications = this.notifications.clone();
 345                                    let is_background = message.is_background();
 346                                    let handle_message = (handler)(this.clone(), message);
 347                                    let handle_message = async move {
 348                                        handle_message.await;
 349                                        if let Some(mut notifications) = notifications {
 350                                            let _ = notifications.send(()).await;
 351                                        }
 352                                    };
 353                                    if is_background {
 354                                        executor.spawn_detached(handle_message);
 355                                    } else {
 356                                        handle_message.await;
 357                                    }
 358                                } else {
 359                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 360                                }
 361                            }.instrument(span).await;
 362                        } else {
 363                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 364                            break;
 365                        }
 366                    }
 367                }
 368            }
 369
 370            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 371            if let Err(error) = this.sign_out(connection_id).await {
 372                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 373            }
 374
 375            Ok(())
 376        }.instrument(span)
 377    }
 378
 379    #[instrument(skip(self), err)]
 380    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 381        self.peer.disconnect(connection_id);
 382
 383        let removed_user_id = {
 384            let mut store = self.store_mut().await;
 385            let removed_connection = store.remove_connection(connection_id)?;
 386
 387            for (project_id, project) in removed_connection.hosted_projects {
 388                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 389                    self.peer
 390                        .send(conn_id, proto::UnregisterProject { project_id })
 391                });
 392
 393                for (_, receipts) in project.join_requests {
 394                    for receipt in receipts {
 395                        self.peer.respond(
 396                            receipt,
 397                            proto::JoinProjectResponse {
 398                                variant: Some(proto::join_project_response::Variant::Decline(
 399                                    proto::join_project_response::Decline {
 400                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 401                                    },
 402                                )),
 403                            },
 404                        )?;
 405                    }
 406                }
 407            }
 408
 409            for project_id in removed_connection.guest_project_ids {
 410                if let Some(project) = store.project(project_id).trace_err() {
 411                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 412                        self.peer.send(
 413                            conn_id,
 414                            proto::RemoveProjectCollaborator {
 415                                project_id,
 416                                peer_id: connection_id.0,
 417                            },
 418                        )
 419                    });
 420                    if project.guests.is_empty() {
 421                        self.peer
 422                            .send(
 423                                project.host_connection_id,
 424                                proto::ProjectUnshared { project_id },
 425                            )
 426                            .trace_err();
 427                    }
 428                }
 429            }
 430
 431            removed_connection.user_id
 432        };
 433
 434        self.update_user_contacts(removed_user_id).await?;
 435
 436        Ok(())
 437    }
 438
 439    pub async fn invite_code_redeemed(
 440        self: &Arc<Self>,
 441        code: &str,
 442        invitee_id: UserId,
 443    ) -> Result<()> {
 444        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 445        let store = self.store().await;
 446        let invitee_contact = store.contact_for_user(invitee_id, true);
 447        for connection_id in store.connection_ids_for_user(user.id) {
 448            self.peer.send(
 449                connection_id,
 450                proto::UpdateContacts {
 451                    contacts: vec![invitee_contact.clone()],
 452                    ..Default::default()
 453                },
 454            )?;
 455            self.peer.send(
 456                connection_id,
 457                proto::UpdateInviteInfo {
 458                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 459                    count: user.invite_count as u32,
 460                },
 461            )?;
 462        }
 463        Ok(())
 464    }
 465
 466    async fn ping(
 467        self: Arc<Server>,
 468        _: TypedEnvelope<proto::Ping>,
 469        response: Response<proto::Ping>,
 470    ) -> Result<()> {
 471        response.send(proto::Ack {})?;
 472        Ok(())
 473    }
 474
 475    async fn register_project(
 476        self: Arc<Server>,
 477        request: TypedEnvelope<proto::RegisterProject>,
 478        response: Response<proto::RegisterProject>,
 479    ) -> Result<()> {
 480        let user_id;
 481        let project_id;
 482        {
 483            let mut state = self.store_mut().await;
 484            user_id = state.user_id_for_connection(request.sender_id)?;
 485            project_id = state.register_project(request.sender_id, user_id);
 486        };
 487        self.update_user_contacts(user_id).await?;
 488        response.send(proto::RegisterProjectResponse { project_id })?;
 489        Ok(())
 490    }
 491
 492    async fn unregister_project(
 493        self: Arc<Server>,
 494        request: TypedEnvelope<proto::UnregisterProject>,
 495    ) -> Result<()> {
 496        let (user_id, project) = {
 497            let mut state = self.store_mut().await;
 498            let project =
 499                state.unregister_project(request.payload.project_id, request.sender_id)?;
 500            (state.user_id_for_connection(request.sender_id)?, project)
 501        };
 502
 503        broadcast(
 504            request.sender_id,
 505            project.guests.keys().copied(),
 506            |conn_id| {
 507                self.peer.send(
 508                    conn_id,
 509                    proto::UnregisterProject {
 510                        project_id: request.payload.project_id,
 511                    },
 512                )
 513            },
 514        );
 515        for (_, receipts) in project.join_requests {
 516            for receipt in receipts {
 517                self.peer.respond(
 518                    receipt,
 519                    proto::JoinProjectResponse {
 520                        variant: Some(proto::join_project_response::Variant::Decline(
 521                            proto::join_project_response::Decline {
 522                                reason: proto::join_project_response::decline::Reason::Closed
 523                                    as i32,
 524                            },
 525                        )),
 526                    },
 527                )?;
 528            }
 529        }
 530
 531        self.update_user_contacts(user_id).await?;
 532        Ok(())
 533    }
 534
 535    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 536        let contacts = self.app_state.db.get_contacts(user_id).await?;
 537        let store = self.store().await;
 538        let updated_contact = store.contact_for_user(user_id, false);
 539        for contact in contacts {
 540            if let db::Contact::Accepted {
 541                user_id: contact_user_id,
 542                ..
 543            } = contact
 544            {
 545                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 546                    self.peer
 547                        .send(
 548                            contact_conn_id,
 549                            proto::UpdateContacts {
 550                                contacts: vec![updated_contact.clone()],
 551                                remove_contacts: Default::default(),
 552                                incoming_requests: Default::default(),
 553                                remove_incoming_requests: Default::default(),
 554                                outgoing_requests: Default::default(),
 555                                remove_outgoing_requests: Default::default(),
 556                            },
 557                        )
 558                        .trace_err();
 559                }
 560            }
 561        }
 562        Ok(())
 563    }
 564
 565    async fn join_project(
 566        self: Arc<Server>,
 567        request: TypedEnvelope<proto::JoinProject>,
 568        response: Response<proto::JoinProject>,
 569    ) -> Result<()> {
 570        let project_id = request.payload.project_id;
 571        let host_user_id;
 572        let guest_user_id;
 573        let host_connection_id;
 574        {
 575            let state = self.store().await;
 576            let project = state.project(project_id)?;
 577            host_user_id = project.host_user_id;
 578            host_connection_id = project.host_connection_id;
 579            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 580        };
 581
 582        let has_contact = self
 583            .app_state
 584            .db
 585            .has_contact(guest_user_id, host_user_id)
 586            .await?;
 587        if !has_contact {
 588            return Err(anyhow!("no such project"))?;
 589        }
 590
 591        self.store_mut().await.request_join_project(
 592            guest_user_id,
 593            project_id,
 594            response.into_receipt(),
 595        )?;
 596        self.peer.send(
 597            host_connection_id,
 598            proto::RequestJoinProject {
 599                project_id,
 600                requester_id: guest_user_id.to_proto(),
 601            },
 602        )?;
 603        Ok(())
 604    }
 605
 606    async fn respond_to_join_project_request(
 607        self: Arc<Server>,
 608        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 609    ) -> Result<()> {
 610        let host_user_id;
 611
 612        {
 613            let mut state = self.store_mut().await;
 614            let project_id = request.payload.project_id;
 615            let project = state.project(project_id)?;
 616            if project.host_connection_id != request.sender_id {
 617                Err(anyhow!("no such connection"))?;
 618            }
 619
 620            host_user_id = project.host_user_id;
 621            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 622
 623            if !request.payload.allow {
 624                let receipts = state
 625                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 626                    .ok_or_else(|| anyhow!("no such request"))?;
 627                for receipt in receipts {
 628                    self.peer.respond(
 629                        receipt,
 630                        proto::JoinProjectResponse {
 631                            variant: Some(proto::join_project_response::Variant::Decline(
 632                                proto::join_project_response::Decline {
 633                                    reason: proto::join_project_response::decline::Reason::Declined
 634                                        as i32,
 635                                },
 636                            )),
 637                        },
 638                    )?;
 639                }
 640                return Ok(());
 641            }
 642
 643            let (receipts_with_replica_ids, project) = state
 644                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 645                .ok_or_else(|| anyhow!("no such request"))?;
 646
 647            let peer_count = project.guests.len();
 648            let mut collaborators = Vec::with_capacity(peer_count);
 649            collaborators.push(proto::Collaborator {
 650                peer_id: project.host_connection_id.0,
 651                replica_id: 0,
 652                user_id: project.host_user_id.to_proto(),
 653            });
 654            let worktrees = project
 655                .worktrees
 656                .iter()
 657                .filter_map(|(id, shared_worktree)| {
 658                    let worktree = project.worktrees.get(&id)?;
 659                    Some(proto::Worktree {
 660                        id: *id,
 661                        root_name: worktree.root_name.clone(),
 662                        entries: shared_worktree.entries.values().cloned().collect(),
 663                        diagnostic_summaries: shared_worktree
 664                            .diagnostic_summaries
 665                            .values()
 666                            .cloned()
 667                            .collect(),
 668                        visible: worktree.visible,
 669                        scan_id: shared_worktree.scan_id,
 670                    })
 671                })
 672                .collect::<Vec<_>>();
 673
 674            // Add all guests other than the requesting user's own connections as collaborators
 675            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 676                if receipts_with_replica_ids
 677                    .iter()
 678                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 679                {
 680                    collaborators.push(proto::Collaborator {
 681                        peer_id: peer_conn_id.0,
 682                        replica_id: *peer_replica_id as u32,
 683                        user_id: peer_user_id.to_proto(),
 684                    });
 685                }
 686            }
 687
 688            for conn_id in project.connection_ids() {
 689                for (receipt, replica_id) in &receipts_with_replica_ids {
 690                    if conn_id != receipt.sender_id {
 691                        self.peer.send(
 692                            conn_id,
 693                            proto::AddProjectCollaborator {
 694                                project_id,
 695                                collaborator: Some(proto::Collaborator {
 696                                    peer_id: receipt.sender_id.0,
 697                                    replica_id: *replica_id as u32,
 698                                    user_id: guest_user_id.to_proto(),
 699                                }),
 700                            },
 701                        )?;
 702                    }
 703                }
 704            }
 705
 706            for (receipt, replica_id) in receipts_with_replica_ids {
 707                self.peer.respond(
 708                    receipt,
 709                    proto::JoinProjectResponse {
 710                        variant: Some(proto::join_project_response::Variant::Accept(
 711                            proto::join_project_response::Accept {
 712                                worktrees: worktrees.clone(),
 713                                replica_id: replica_id as u32,
 714                                collaborators: collaborators.clone(),
 715                                language_servers: project.language_servers.clone(),
 716                            },
 717                        )),
 718                    },
 719                )?;
 720            }
 721        }
 722
 723        self.update_user_contacts(host_user_id).await?;
 724        Ok(())
 725    }
 726
    /// Handles `LeaveProject`: removes the sender from the project in the store and
    /// notifies the other affected connections.
    ///
    /// Depending on the flags returned by `Store::leave_project`, this:
    /// - tells every other project connection that the sender is no longer a collaborator,
    /// - tells the host that the sender's pending join request was cancelled,
    /// - tells the host that the project has been unshared.
    ///
    /// Finally, it refreshes the host user's contacts.
    async fn leave_project(
        self: Arc<Server>,
        request: TypedEnvelope<proto::LeaveProject>,
    ) -> Result<()> {
        let sender_id = request.sender_id;
        let project_id = request.payload.project_id;
        let project;
        // The write guard is confined to this block, so it is released before
        // `update_user_contacts` is awaited below.
        // NOTE(review): presumably required because that call reads the store itself — confirm.
        {
            let mut store = self.store_mut().await;
            project = store.leave_project(sender_id, project_id)?;

            if project.remove_collaborator {
                // Per-connection send failures are only traced (see `broadcast`).
                broadcast(sender_id, project.connection_ids, |conn_id| {
                    self.peer.send(
                        conn_id,
                        proto::RemoveProjectCollaborator {
                            project_id,
                            peer_id: sender_id.0,
                        },
                    )
                });
            }

            if let Some(requester_id) = project.cancel_request {
                // A send failure here returns early, which also skips the contact refresh.
                self.peer.send(
                    project.host_connection_id,
                    proto::JoinProjectRequestCancelled {
                        project_id,
                        requester_id: requester_id.to_proto(),
                    },
                )?;
            }

            if project.unshare {
                self.peer.send(
                    project.host_connection_id,
                    proto::ProjectUnshared { project_id },
                )?;
            }
        }
        self.update_user_contacts(project.host_user_id).await?;
        Ok(())
    }
 770
 771    async fn register_worktree(
 772        self: Arc<Server>,
 773        request: TypedEnvelope<proto::RegisterWorktree>,
 774        response: Response<proto::RegisterWorktree>,
 775    ) -> Result<()> {
 776        let host_user_id;
 777        {
 778            let mut state = self.store_mut().await;
 779            host_user_id = state.user_id_for_connection(request.sender_id)?;
 780
 781            let guest_connection_ids = state
 782                .read_project(request.payload.project_id, request.sender_id)?
 783                .guest_connection_ids();
 784            state.register_worktree(
 785                request.payload.project_id,
 786                request.payload.worktree_id,
 787                request.sender_id,
 788                Worktree {
 789                    root_name: request.payload.root_name.clone(),
 790                    visible: request.payload.visible,
 791                    ..Default::default()
 792                },
 793            )?;
 794
 795            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 796                self.peer
 797                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 798            });
 799        }
 800        self.update_user_contacts(host_user_id).await?;
 801        response.send(proto::Ack {})?;
 802        Ok(())
 803    }
 804
 805    async fn unregister_worktree(
 806        self: Arc<Server>,
 807        request: TypedEnvelope<proto::UnregisterWorktree>,
 808    ) -> Result<()> {
 809        let host_user_id;
 810        let project_id = request.payload.project_id;
 811        let worktree_id = request.payload.worktree_id;
 812        {
 813            let mut state = self.store_mut().await;
 814            let (_, guest_connection_ids) =
 815                state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
 816            host_user_id = state.user_id_for_connection(request.sender_id)?;
 817            broadcast(request.sender_id, guest_connection_ids, |conn_id| {
 818                self.peer.send(
 819                    conn_id,
 820                    proto::UnregisterWorktree {
 821                        project_id,
 822                        worktree_id,
 823                    },
 824                )
 825            });
 826        }
 827        self.update_user_contacts(host_user_id).await?;
 828        Ok(())
 829    }
 830
 831    async fn update_worktree(
 832        self: Arc<Server>,
 833        request: TypedEnvelope<proto::UpdateWorktree>,
 834        response: Response<proto::UpdateWorktree>,
 835    ) -> Result<()> {
 836        let connection_ids = self.store_mut().await.update_worktree(
 837            request.sender_id,
 838            request.payload.project_id,
 839            request.payload.worktree_id,
 840            &request.payload.removed_entries,
 841            &request.payload.updated_entries,
 842            request.payload.scan_id,
 843        )?;
 844
 845        broadcast(request.sender_id, connection_ids, |connection_id| {
 846            self.peer
 847                .forward_send(request.sender_id, connection_id, request.payload.clone())
 848        });
 849        response.send(proto::Ack {})?;
 850        Ok(())
 851    }
 852
 853    async fn update_diagnostic_summary(
 854        self: Arc<Server>,
 855        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 856    ) -> Result<()> {
 857        let summary = request
 858            .payload
 859            .summary
 860            .clone()
 861            .ok_or_else(|| anyhow!("invalid summary"))?;
 862        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 863            request.payload.project_id,
 864            request.payload.worktree_id,
 865            request.sender_id,
 866            summary,
 867        )?;
 868
 869        broadcast(request.sender_id, receiver_ids, |connection_id| {
 870            self.peer
 871                .forward_send(request.sender_id, connection_id, request.payload.clone())
 872        });
 873        Ok(())
 874    }
 875
 876    async fn start_language_server(
 877        self: Arc<Server>,
 878        request: TypedEnvelope<proto::StartLanguageServer>,
 879    ) -> Result<()> {
 880        let receiver_ids = self.store_mut().await.start_language_server(
 881            request.payload.project_id,
 882            request.sender_id,
 883            request
 884                .payload
 885                .server
 886                .clone()
 887                .ok_or_else(|| anyhow!("invalid language server"))?,
 888        )?;
 889        broadcast(request.sender_id, receiver_ids, |connection_id| {
 890            self.peer
 891                .forward_send(request.sender_id, connection_id, request.payload.clone())
 892        });
 893        Ok(())
 894    }
 895
 896    async fn update_language_server(
 897        self: Arc<Server>,
 898        request: TypedEnvelope<proto::UpdateLanguageServer>,
 899    ) -> Result<()> {
 900        let receiver_ids = self
 901            .store()
 902            .await
 903            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 904        broadcast(request.sender_id, receiver_ids, |connection_id| {
 905            self.peer
 906                .forward_send(request.sender_id, connection_id, request.payload.clone())
 907        });
 908        Ok(())
 909    }
 910
 911    async fn forward_project_request<T>(
 912        self: Arc<Server>,
 913        request: TypedEnvelope<T>,
 914        response: Response<T>,
 915    ) -> Result<()>
 916    where
 917        T: EntityMessage + RequestMessage,
 918    {
 919        let host_connection_id = self
 920            .store()
 921            .await
 922            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 923            .host_connection_id;
 924
 925        response.send(
 926            self.peer
 927                .forward_request(request.sender_id, host_connection_id, request.payload)
 928                .await?,
 929        )?;
 930        Ok(())
 931    }
 932
    /// Handles `SaveBuffer`: asks the project host to save the buffer, then shares the
    /// host's response with the other project members and returns it to the requester.
    async fn save_buffer(
        self: Arc<Server>,
        request: TypedEnvelope<proto::SaveBuffer>,
        response: Response<proto::SaveBuffer>,
    ) -> Result<()> {
        // The read guard is a statement temporary, so it is not held across the awaited
        // request to the host below.
        let host = self
            .store()
            .await
            .read_project(request.payload.project_id, request.sender_id)?
            .host_connection_id;
        let response_payload = self
            .peer
            .forward_request(request.sender_id, host, request.payload.clone())
            .await?;

        // Re-read the project after the await, since membership may have changed while
        // waiting on the host. If the project can no longer be read here, an error is
        // returned and the host's response is not delivered by this handler.
        let mut guests = self
            .store()
            .await
            .read_project(request.payload.project_id, request.sender_id)?
            .connection_ids();
        // The requester gets the payload through `response` below, so it is excluded here.
        // `broadcast` additionally skips `host`, the message's originator.
        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
        broadcast(host, guests, |conn_id| {
            self.peer
                .forward_send(host, conn_id, response_payload.clone())
        });
        response.send(response_payload)?;
        Ok(())
    }
 961
 962    async fn update_buffer(
 963        self: Arc<Server>,
 964        request: TypedEnvelope<proto::UpdateBuffer>,
 965        response: Response<proto::UpdateBuffer>,
 966    ) -> Result<()> {
 967        let receiver_ids = self
 968            .store()
 969            .await
 970            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 971        broadcast(request.sender_id, receiver_ids, |connection_id| {
 972            self.peer
 973                .forward_send(request.sender_id, connection_id, request.payload.clone())
 974        });
 975        response.send(proto::Ack {})?;
 976        Ok(())
 977    }
 978
 979    async fn update_buffer_file(
 980        self: Arc<Server>,
 981        request: TypedEnvelope<proto::UpdateBufferFile>,
 982    ) -> Result<()> {
 983        let receiver_ids = self
 984            .store()
 985            .await
 986            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 987        broadcast(request.sender_id, receiver_ids, |connection_id| {
 988            self.peer
 989                .forward_send(request.sender_id, connection_id, request.payload.clone())
 990        });
 991        Ok(())
 992    }
 993
 994    async fn buffer_reloaded(
 995        self: Arc<Server>,
 996        request: TypedEnvelope<proto::BufferReloaded>,
 997    ) -> Result<()> {
 998        let receiver_ids = self
 999            .store()
1000            .await
1001            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1002        broadcast(request.sender_id, receiver_ids, |connection_id| {
1003            self.peer
1004                .forward_send(request.sender_id, connection_id, request.payload.clone())
1005        });
1006        Ok(())
1007    }
1008
1009    async fn buffer_saved(
1010        self: Arc<Server>,
1011        request: TypedEnvelope<proto::BufferSaved>,
1012    ) -> Result<()> {
1013        let receiver_ids = self
1014            .store()
1015            .await
1016            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1017        broadcast(request.sender_id, receiver_ids, |connection_id| {
1018            self.peer
1019                .forward_send(request.sender_id, connection_id, request.payload.clone())
1020        });
1021        Ok(())
1022    }
1023
1024    async fn follow(
1025        self: Arc<Self>,
1026        request: TypedEnvelope<proto::Follow>,
1027        response: Response<proto::Follow>,
1028    ) -> Result<()> {
1029        let leader_id = ConnectionId(request.payload.leader_id);
1030        let follower_id = request.sender_id;
1031        if !self
1032            .store()
1033            .await
1034            .project_connection_ids(request.payload.project_id, follower_id)?
1035            .contains(&leader_id)
1036        {
1037            Err(anyhow!("no such peer"))?;
1038        }
1039        let mut response_payload = self
1040            .peer
1041            .forward_request(request.sender_id, leader_id, request.payload)
1042            .await?;
1043        response_payload
1044            .views
1045            .retain(|view| view.leader_id != Some(follower_id.0));
1046        response.send(response_payload)?;
1047        Ok(())
1048    }
1049
1050    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1051        let leader_id = ConnectionId(request.payload.leader_id);
1052        if !self
1053            .store()
1054            .await
1055            .project_connection_ids(request.payload.project_id, request.sender_id)?
1056            .contains(&leader_id)
1057        {
1058            Err(anyhow!("no such peer"))?;
1059        }
1060        self.peer
1061            .forward_send(request.sender_id, leader_id, request.payload)?;
1062        Ok(())
1063    }
1064
1065    async fn update_followers(
1066        self: Arc<Self>,
1067        request: TypedEnvelope<proto::UpdateFollowers>,
1068    ) -> Result<()> {
1069        let connection_ids = self
1070            .store()
1071            .await
1072            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1073        let leader_id = request
1074            .payload
1075            .variant
1076            .as_ref()
1077            .and_then(|variant| match variant {
1078                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1079                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1080                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1081            });
1082        for follower_id in &request.payload.follower_ids {
1083            let follower_id = ConnectionId(*follower_id);
1084            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1085                self.peer
1086                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1087            }
1088        }
1089        Ok(())
1090    }
1091
1092    async fn get_channels(
1093        self: Arc<Server>,
1094        request: TypedEnvelope<proto::GetChannels>,
1095        response: Response<proto::GetChannels>,
1096    ) -> Result<()> {
1097        let user_id = self
1098            .store()
1099            .await
1100            .user_id_for_connection(request.sender_id)?;
1101        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1102        response.send(proto::GetChannelsResponse {
1103            channels: channels
1104                .into_iter()
1105                .map(|chan| proto::Channel {
1106                    id: chan.id.to_proto(),
1107                    name: chan.name,
1108                })
1109                .collect(),
1110        })?;
1111        Ok(())
1112    }
1113
1114    async fn get_users(
1115        self: Arc<Server>,
1116        request: TypedEnvelope<proto::GetUsers>,
1117        response: Response<proto::GetUsers>,
1118    ) -> Result<()> {
1119        let user_ids = request
1120            .payload
1121            .user_ids
1122            .into_iter()
1123            .map(UserId::from_proto)
1124            .collect();
1125        let users = self
1126            .app_state
1127            .db
1128            .get_users_by_ids(user_ids)
1129            .await?
1130            .into_iter()
1131            .map(|user| proto::User {
1132                id: user.id.to_proto(),
1133                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1134                github_login: user.github_login,
1135            })
1136            .collect();
1137        response.send(proto::UsersResponse { users })?;
1138        Ok(())
1139    }
1140
1141    async fn fuzzy_search_users(
1142        self: Arc<Server>,
1143        request: TypedEnvelope<proto::FuzzySearchUsers>,
1144        response: Response<proto::FuzzySearchUsers>,
1145    ) -> Result<()> {
1146        let user_id = self
1147            .store()
1148            .await
1149            .user_id_for_connection(request.sender_id)?;
1150        let query = request.payload.query;
1151        let db = &self.app_state.db;
1152        let users = match query.len() {
1153            0 => vec![],
1154            1 | 2 => db
1155                .get_user_by_github_login(&query)
1156                .await?
1157                .into_iter()
1158                .collect(),
1159            _ => db.fuzzy_search_users(&query, 10).await?,
1160        };
1161        let users = users
1162            .into_iter()
1163            .filter(|user| user.id != user_id)
1164            .map(|user| proto::User {
1165                id: user.id.to_proto(),
1166                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1167                github_login: user.github_login,
1168            })
1169            .collect();
1170        response.send(proto::UsersResponse { users })?;
1171        Ok(())
1172    }
1173
1174    async fn request_contact(
1175        self: Arc<Server>,
1176        request: TypedEnvelope<proto::RequestContact>,
1177        response: Response<proto::RequestContact>,
1178    ) -> Result<()> {
1179        let requester_id = self
1180            .store()
1181            .await
1182            .user_id_for_connection(request.sender_id)?;
1183        let responder_id = UserId::from_proto(request.payload.responder_id);
1184        if requester_id == responder_id {
1185            return Err(anyhow!("cannot add yourself as a contact"))?;
1186        }
1187
1188        self.app_state
1189            .db
1190            .send_contact_request(requester_id, responder_id)
1191            .await?;
1192
1193        // Update outgoing contact requests of requester
1194        let mut update = proto::UpdateContacts::default();
1195        update.outgoing_requests.push(responder_id.to_proto());
1196        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1197            self.peer.send(connection_id, update.clone())?;
1198        }
1199
1200        // Update incoming contact requests of responder
1201        let mut update = proto::UpdateContacts::default();
1202        update
1203            .incoming_requests
1204            .push(proto::IncomingContactRequest {
1205                requester_id: requester_id.to_proto(),
1206                should_notify: true,
1207            });
1208        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1209            self.peer.send(connection_id, update.clone())?;
1210        }
1211
1212        response.send(proto::Ack {})?;
1213        Ok(())
1214    }
1215
    /// Handles `RespondToContactRequest` from the responder (the request's recipient).
    ///
    /// - `Dismiss`: only records the dismissal in the database; no updates are pushed.
    /// - Any other response: records whether it was an `Accept`. Both users' connections
    ///   then receive an `UpdateContacts` that clears the pending request and, when
    ///   accepted, adds the other user as a contact.
    async fn respond_to_contact_request(
        self: Arc<Server>,
        request: TypedEnvelope<proto::RespondToContactRequest>,
        response: Response<proto::RespondToContactRequest>,
    ) -> Result<()> {
        let responder_id = self
            .store()
            .await
            .user_id_for_connection(request.sender_id)?;
        let requester_id = UserId::from_proto(request.payload.requester_id);
        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
            self.app_state
                .db
                .dismiss_contact_notification(responder_id, requester_id)
                .await?;
        } else {
            // Every non-dismiss response other than `Accept` is treated as a decline.
            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
            self.app_state
                .db
                .respond_to_contact_request(responder_id, requester_id, accept)
                .await?;

            // A single read guard is held for both sets of pushes below (there are no
            // awaits while it is held).
            let store = self.store().await;
            // Update responder with new contact
            let mut update = proto::UpdateContacts::default();
            if accept {
                // NOTE(review): the bool argument differs between the two sides (false here,
                // true below); its meaning is defined by `Store::contact_for_user` — confirm.
                update
                    .contacts
                    .push(store.contact_for_user(requester_id, false));
            }
            update
                .remove_incoming_requests
                .push(requester_id.to_proto());
            for connection_id in store.connection_ids_for_user(responder_id) {
                self.peer.send(connection_id, update.clone())?;
            }

            // Update requester with new contact
            let mut update = proto::UpdateContacts::default();
            if accept {
                update
                    .contacts
                    .push(store.contact_for_user(responder_id, true));
            }
            update
                .remove_outgoing_requests
                .push(responder_id.to_proto());
            for connection_id in store.connection_ids_for_user(requester_id) {
                self.peer.send(connection_id, update.clone())?;
            }
        }

        response.send(proto::Ack {})?;
        Ok(())
    }
1271
1272    async fn remove_contact(
1273        self: Arc<Server>,
1274        request: TypedEnvelope<proto::RemoveContact>,
1275        response: Response<proto::RemoveContact>,
1276    ) -> Result<()> {
1277        let requester_id = self
1278            .store()
1279            .await
1280            .user_id_for_connection(request.sender_id)?;
1281        let responder_id = UserId::from_proto(request.payload.user_id);
1282        self.app_state
1283            .db
1284            .remove_contact(requester_id, responder_id)
1285            .await?;
1286
1287        // Update outgoing contact requests of requester
1288        let mut update = proto::UpdateContacts::default();
1289        update
1290            .remove_outgoing_requests
1291            .push(responder_id.to_proto());
1292        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1293            self.peer.send(connection_id, update.clone())?;
1294        }
1295
1296        // Update incoming contact requests of responder
1297        let mut update = proto::UpdateContacts::default();
1298        update
1299            .remove_incoming_requests
1300            .push(requester_id.to_proto());
1301        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1302            self.peer.send(connection_id, update.clone())?;
1303        }
1304
1305        response.send(proto::Ack {})?;
1306        Ok(())
1307    }
1308
1309    async fn join_channel(
1310        self: Arc<Self>,
1311        request: TypedEnvelope<proto::JoinChannel>,
1312        response: Response<proto::JoinChannel>,
1313    ) -> Result<()> {
1314        let user_id = self
1315            .store()
1316            .await
1317            .user_id_for_connection(request.sender_id)?;
1318        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1319        if !self
1320            .app_state
1321            .db
1322            .can_user_access_channel(user_id, channel_id)
1323            .await?
1324        {
1325            Err(anyhow!("access denied"))?;
1326        }
1327
1328        self.store_mut()
1329            .await
1330            .join_channel(request.sender_id, channel_id);
1331        let messages = self
1332            .app_state
1333            .db
1334            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1335            .await?
1336            .into_iter()
1337            .map(|msg| proto::ChannelMessage {
1338                id: msg.id.to_proto(),
1339                body: msg.body,
1340                timestamp: msg.sent_at.unix_timestamp() as u64,
1341                sender_id: msg.sender_id.to_proto(),
1342                nonce: Some(msg.nonce.as_u128().into()),
1343            })
1344            .collect::<Vec<_>>();
1345        response.send(proto::JoinChannelResponse {
1346            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1347            messages,
1348        })?;
1349        Ok(())
1350    }
1351
1352    async fn leave_channel(
1353        self: Arc<Self>,
1354        request: TypedEnvelope<proto::LeaveChannel>,
1355    ) -> Result<()> {
1356        let user_id = self
1357            .store()
1358            .await
1359            .user_id_for_connection(request.sender_id)?;
1360        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1361        if !self
1362            .app_state
1363            .db
1364            .can_user_access_channel(user_id, channel_id)
1365            .await?
1366        {
1367            Err(anyhow!("access denied"))?;
1368        }
1369
1370        self.store_mut()
1371            .await
1372            .leave_channel(request.sender_id, channel_id);
1373
1374        Ok(())
1375    }
1376
1377    async fn send_channel_message(
1378        self: Arc<Self>,
1379        request: TypedEnvelope<proto::SendChannelMessage>,
1380        response: Response<proto::SendChannelMessage>,
1381    ) -> Result<()> {
1382        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1383        let user_id;
1384        let connection_ids;
1385        {
1386            let state = self.store().await;
1387            user_id = state.user_id_for_connection(request.sender_id)?;
1388            connection_ids = state.channel_connection_ids(channel_id)?;
1389        }
1390
1391        // Validate the message body.
1392        let body = request.payload.body.trim().to_string();
1393        if body.len() > MAX_MESSAGE_LEN {
1394            return Err(anyhow!("message is too long"))?;
1395        }
1396        if body.is_empty() {
1397            return Err(anyhow!("message can't be blank"))?;
1398        }
1399
1400        let timestamp = OffsetDateTime::now_utc();
1401        let nonce = request
1402            .payload
1403            .nonce
1404            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1405
1406        let message_id = self
1407            .app_state
1408            .db
1409            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1410            .await?
1411            .to_proto();
1412        let message = proto::ChannelMessage {
1413            sender_id: user_id.to_proto(),
1414            id: message_id,
1415            body,
1416            timestamp: timestamp.unix_timestamp() as u64,
1417            nonce: Some(nonce),
1418        };
1419        broadcast(request.sender_id, connection_ids, |conn_id| {
1420            self.peer.send(
1421                conn_id,
1422                proto::ChannelMessageSent {
1423                    channel_id: channel_id.to_proto(),
1424                    message: Some(message.clone()),
1425                },
1426            )
1427        });
1428        response.send(proto::SendChannelMessageResponse {
1429            message: Some(message),
1430        })?;
1431        Ok(())
1432    }
1433
1434    async fn get_channel_messages(
1435        self: Arc<Self>,
1436        request: TypedEnvelope<proto::GetChannelMessages>,
1437        response: Response<proto::GetChannelMessages>,
1438    ) -> Result<()> {
1439        let user_id = self
1440            .store()
1441            .await
1442            .user_id_for_connection(request.sender_id)?;
1443        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1444        if !self
1445            .app_state
1446            .db
1447            .can_user_access_channel(user_id, channel_id)
1448            .await?
1449        {
1450            Err(anyhow!("access denied"))?;
1451        }
1452
1453        let messages = self
1454            .app_state
1455            .db
1456            .get_channel_messages(
1457                channel_id,
1458                MESSAGE_COUNT_PER_PAGE,
1459                Some(MessageId::from_proto(request.payload.before_message_id)),
1460            )
1461            .await?
1462            .into_iter()
1463            .map(|msg| proto::ChannelMessage {
1464                id: msg.id.to_proto(),
1465                body: msg.body,
1466                timestamp: msg.sent_at.unix_timestamp() as u64,
1467                sender_id: msg.sender_id.to_proto(),
1468                nonce: Some(msg.nonce.as_u128().into()),
1469            })
1470            .collect::<Vec<_>>();
1471        response.send(proto::GetChannelMessagesResponse {
1472            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1473            messages,
1474        })?;
1475        Ok(())
1476    }
1477
    /// Acquires a read lock on the shared `Store` and wraps it in a `StoreReadGuard`.
    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
        // Test builds yield to the Tokio scheduler before and after taking the lock.
        // NOTE(review): presumably this perturbs task interleavings so tests exercise more
        // orderings — confirm.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.read().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreReadGuard {
            guard,
            // The field name indicates this marker keeps the guard from being `Send`, so it
            // cannot be held across an await in a spawned task; its type is declared elsewhere.
            _not_send: PhantomData,
        }
    }
1489
    /// Acquires a write lock on the shared `Store` and wraps it in a `StoreWriteGuard`.
    /// That guard's `Drop` impl runs when the lock is released.
    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
        // Test builds yield to the Tokio scheduler before and after taking the lock.
        // NOTE(review): presumably to widen the interleavings tests observe — confirm.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.write().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreWriteGuard {
            guard,
            // Marker named to keep the guard non-`Send`; its type is declared elsewhere.
            _not_send: PhantomData,
        }
    }
1501
1502    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1503        ServerSnapshot {
1504            store: self.store.read().await,
1505            peer: &self.peer,
1506        }
1507    }
1508}
1509
1510impl<'a> Deref for StoreReadGuard<'a> {
1511    type Target = Store;
1512
1513    fn deref(&self) -> &Self::Target {
1514        &*self.guard
1515    }
1516}
1517
1518impl<'a> Deref for StoreWriteGuard<'a> {
1519    type Target = Store;
1520
1521    fn deref(&self) -> &Self::Target {
1522        &*self.guard
1523    }
1524}
1525
1526impl<'a> DerefMut for StoreWriteGuard<'a> {
1527    fn deref_mut(&mut self) -> &mut Self::Target {
1528        &mut *self.guard
1529    }
1530}
1531
/// Runs whenever a write lock on the store is released.
impl<'a> Drop for StoreWriteGuard<'a> {
    fn drop(&mut self) {
        // Test builds validate store invariants after every mutation.
        // NOTE(review): `check_invariants` and `metrics` are presumably `Store` methods
        // reached through this guard's `Deref` — confirm.
        #[cfg(test)]
        self.check_invariants();

        // Emits an info-level "metrics" event with connection/project counts on every
        // write-lock release.
        let metrics = self.metrics();
        tracing::info!(
            connections = metrics.connections,
            registered_projects = metrics.registered_projects,
            shared_projects = metrics.shared_projects,
            "metrics"
        );
    }
}
1546
1547impl Executor for RealExecutor {
1548    type Sleep = Sleep;
1549
1550    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1551        tokio::task::spawn(future);
1552    }
1553
1554    fn sleep(&self, duration: Duration) -> Self::Sleep {
1555        tokio::time::sleep(duration)
1556    }
1557}
1558
1559fn broadcast<F>(
1560    sender_id: ConnectionId,
1561    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1562    mut f: F,
1563) where
1564    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1565{
1566    for receiver_id in receiver_ids {
1567        if receiver_id != sender_id {
1568            f(receiver_id).trace_err();
1569        }
1570    }
1571}
1572
1573lazy_static! {
1574    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1575}
1576
1577pub struct ProtocolVersion(u32);
1578
1579impl Header for ProtocolVersion {
1580    fn name() -> &'static HeaderName {
1581        &ZED_PROTOCOL_VERSION
1582    }
1583
1584    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1585    where
1586        Self: Sized,
1587        I: Iterator<Item = &'i axum::http::HeaderValue>,
1588    {
1589        let version = values
1590            .next()
1591            .ok_or_else(|| axum::headers::Error::invalid())?
1592            .to_str()
1593            .map_err(|_| axum::headers::Error::invalid())?
1594            .parse()
1595            .map_err(|_| axum::headers::Error::invalid())?;
1596        Ok(Self(version))
1597    }
1598
1599    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1600        values.extend([self.0.to_string().parse().unwrap()]);
1601    }
1602}
1603
/// Builds the router that serves the `/rpc` WebSocket endpoint.
///
/// Requests pass through these middleware layers:
/// - an `Extension` holding the app state;
/// - `auth::validate_header`;
/// - an `Extension` holding the server.
///
/// `ServiceBuilder` applies layers in the order they are added.
fn routes_doc_anchor() {}
1614
/// Upgrades an authenticated `/rpc` request to a WebSocket and hands the connection to
/// the server.
///
/// Responds with `426 Upgrade Required` ("client must be upgraded") when the client's
/// `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's WebSocket to the tungstenite message type that `Connection` uses:
        // - incoming frames are mapped to tungstenite messages;
        // - errors are converted with `err_into`;
        // - outgoing tungstenite messages are converted back to axum messages before sending.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Any error from the connection's lifetime is logged via `log_err`, not returned.
            // NOTE(review): the `None` argument's meaning is defined by
            // `Server::handle_connection` — confirm.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1645
1646fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1647    match message {
1648        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1649        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1650        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1651        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1652        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1653            code: frame.code.into(),
1654            reason: frame.reason,
1655        })),
1656    }
1657}
1658
1659fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1660    match message {
1661        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1662        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1663        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1664        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1665        AxumMessage::Close(frame) => {
1666            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1667                code: frame.code.into(),
1668                reason: frame.reason,
1669            }))
1670        }
1671    }
1672}
1673
/// Extension trait for turning a fallible value into an `Option` while
/// reporting the failure instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Returns the success value as `Some`. On failure, reports the error and
    /// returns `None`. (The `Result` impl in this module reports it with
    /// `tracing::error!`.)
    fn trace_err(self) -> Option<Self::Ok>;
}
1679
1680impl<T, E> ResultExt for Result<T, E>
1681where
1682    E: std::fmt::Debug,
1683{
1684    type Ok = T;
1685
1686    fn trace_err(self) -> Option<T> {
1687        match self {
1688            Ok(value) => Some(value),
1689            Err(error) => {
1690                tracing::error!("{:?}", error);
1691                None
1692            }
1693        }
1694    }
1695}