rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use prometheus::{register_int_gauge, IntGauge};
  33use rpc::{
  34    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  35    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  36};
  37use serde::{Serialize, Serializer};
  38use std::{
  39    any::TypeId,
  40    future::Future,
  41    marker::PhantomData,
  42    net::SocketAddr,
  43    ops::{Deref, DerefMut},
  44    rc::Rc,
  45    sync::{
  46        atomic::{AtomicBool, Ordering::SeqCst},
  47        Arc,
  48    },
  49    time::Duration,
  50};
  51use time::OffsetDateTime;
  52use tokio::{
  53    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  54    time::Sleep,
  55};
  56use tower::ServiceBuilder;
  57use tracing::{info_span, instrument, Instrument};
  58
  59pub use store::{Store, Worktree};
  60
// Process-wide Prometheus gauges, created lazily on first access. Each one is registered
// through `register_int_gauge!`; the `unwrap()` means a failed registration panics at that
// first access.
lazy_static! {
    // Gauge named "connections" ("number of connections").
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge named "projects" ("number of open projects").
    static ref METRIC_PROJECTS: IntGauge =
        register_int_gauge!("projects", "number of open projects").unwrap();
    // Gauge named "shared_projects" ("number of open projects with one or more guests").
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  72
/// Type-erased message handler stored in `Server::handlers`, keyed by message `TypeId`.
///
/// It receives the server and a boxed envelope and returns a `'static` boxed future that
/// performs the handling. `add_message_handler` builds these wrappers; each wrapper downcasts
/// the envelope back to its concrete `TypedEnvelope<M>`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  75
/// Reply handle given to request handlers registered with `Server::add_request_handler`.
struct Response<R> {
    /// Server whose peer is used to send the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Shared with the `add_request_handler` wrapper, which reads it after the handler
    /// finishes to detect handlers that never replied.
    responded: Arc<AtomicBool>,
}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.server.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88
  89    fn into_receipt(self) -> Receipt<R> {
  90        self.responded.store(true, SeqCst);
  91        self.receipt
  92    }
  93}
  94
/// The collaboration RPC server: dispatches incoming messages to registered handlers and
/// owns the shared in-memory [`Store`] of connections and projects.
pub struct Server {
    /// RPC peer used to send messages and responses to connected clients.
    peer: Arc<Peer>,
    /// Shared server state, guarded by an async (tokio) `RwLock`.
    pub(crate) store: RwLock<Store>,
    /// Application-wide state (database handle, invite link prefix, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When present, receives a `()` each time an incoming message has been handled.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
 102
/// Runtime services needed by `Server::handle_connection`: spawning detached tasks and
/// creating timers.
pub trait Executor: Send + Clone {
    /// Future type returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` in the background. No handle is returned, so the task cannot be
    /// awaited or cancelled through this trait.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration`; `handle_connection` supplies these timers to
    /// `Peer::add_connection`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 108
/// Production [`Executor`]. NOTE(review): presumably backed by the tokio runtime — confirm at
/// its `impl Executor`.
#[derive(Clone)]
pub struct RealExecutor;
 111
/// Page size for channel message history. NOTE(review): meaning inferred from the name —
/// confirm at the use site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length. NOTE(review): whether this counts bytes or
/// chars is not established here — confirm at the use site.
const MAX_MESSAGE_LEN: usize = 1024;
 114
/// Read guard over the server's [`Store`].
///
/// `PhantomData<Rc<()>>` makes this type `!Send`. A future that keeps the guard alive across
/// an `.await` is therefore `!Send`, and handler futures are required to be `Send`. As a
/// result, holding the store lock across an await point is a compile error.
struct StoreReadGuard<'a> {
    guard: RwLockReadGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}

/// Write guard over the server's [`Store`]; `!Send` for the same reason as
/// `StoreReadGuard`.
struct StoreWriteGuard<'a> {
    guard: RwLockWriteGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 124
/// Serializable view of the server's peer and store.
///
/// It holds a read lock on the store for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the `Store` it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
 131
 132pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 133where
 134    S: Serializer,
 135    T: Deref<Target = U>,
 136    U: Serialize,
 137{
 138    Serialize::serialize(value.deref(), serializer)
 139}
 140
 141impl Server {
 142    pub fn new(
 143        app_state: Arc<AppState>,
 144        notifications: Option<mpsc::UnboundedSender<()>>,
 145    ) -> Arc<Self> {
 146        let mut server = Self {
 147            peer: Peer::new(),
 148            app_state,
 149            store: Default::default(),
 150            handlers: Default::default(),
 151            notifications,
 152        };
 153
 154        server
 155            .add_request_handler(Server::ping)
 156            .add_request_handler(Server::register_project)
 157            .add_request_handler(Server::unregister_project)
 158            .add_request_handler(Server::join_project)
 159            .add_message_handler(Server::leave_project)
 160            .add_message_handler(Server::respond_to_join_project_request)
 161            .add_message_handler(Server::update_project)
 162            .add_request_handler(Server::update_worktree)
 163            .add_message_handler(Server::start_language_server)
 164            .add_message_handler(Server::update_language_server)
 165            .add_message_handler(Server::update_diagnostic_summary)
 166            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 167            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 168            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 169            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 170            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 172            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 173            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 174            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 175            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 176            .add_request_handler(
 177                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 178            )
 179            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 180            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 181            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 182            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 183            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 184            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 185            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 186            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 187            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 188            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 189            .add_request_handler(Server::update_buffer)
 190            .add_message_handler(Server::update_buffer_file)
 191            .add_message_handler(Server::buffer_reloaded)
 192            .add_message_handler(Server::buffer_saved)
 193            .add_request_handler(Server::save_buffer)
 194            .add_request_handler(Server::get_channels)
 195            .add_request_handler(Server::get_users)
 196            .add_request_handler(Server::fuzzy_search_users)
 197            .add_request_handler(Server::request_contact)
 198            .add_request_handler(Server::remove_contact)
 199            .add_request_handler(Server::respond_to_contact_request)
 200            .add_request_handler(Server::join_channel)
 201            .add_message_handler(Server::leave_channel)
 202            .add_request_handler(Server::send_channel_message)
 203            .add_request_handler(Server::follow)
 204            .add_message_handler(Server::unfollow)
 205            .add_message_handler(Server::update_followers)
 206            .add_request_handler(Server::get_channel_messages);
 207
 208        Arc::new(server)
 209    }
 210
 211    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 212    where
 213        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 214        Fut: 'static + Send + Future<Output = Result<()>>,
 215        M: EnvelopedMessage,
 216    {
 217        let prev_handler = self.handlers.insert(
 218            TypeId::of::<M>(),
 219            Box::new(move |server, envelope| {
 220                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 221                let span = info_span!(
 222                    "handle message",
 223                    payload_type = envelope.payload_type_name(),
 224                    payload = format!("{:?}", envelope.payload).as_str(),
 225                );
 226                let future = (handler)(server, *envelope);
 227                async move {
 228                    if let Err(error) = future.await {
 229                        tracing::error!(%error, "error handling message");
 230                    }
 231                }
 232                .instrument(span)
 233                .boxed()
 234            }),
 235        );
 236        if prev_handler.is_some() {
 237            panic!("registered a handler for the same message twice");
 238        }
 239        self
 240    }
 241
 242    /// Handle a request while holding a lock to the store. This is useful when we're registering
 243    /// a connection but we want to respond on the connection before anybody else can send on it.
 244    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 245    where
 246        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 247        Fut: Send + Future<Output = Result<()>>,
 248        M: RequestMessage,
 249    {
 250        let handler = Arc::new(handler);
 251        self.add_message_handler(move |server, envelope| {
 252            let receipt = envelope.receipt();
 253            let handler = handler.clone();
 254            async move {
 255                let responded = Arc::new(AtomicBool::default());
 256                let response = Response {
 257                    server: server.clone(),
 258                    responded: responded.clone(),
 259                    receipt: envelope.receipt(),
 260                };
 261                match (handler)(server.clone(), envelope, response).await {
 262                    Ok(()) => {
 263                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 264                            Ok(())
 265                        } else {
 266                            Err(anyhow!("handler did not send a response"))?
 267                        }
 268                    }
 269                    Err(error) => {
 270                        server.peer.respond_with_error(
 271                            receipt,
 272                            proto::Error {
 273                                message: error.to_string(),
 274                            },
 275                        )?;
 276                        Err(error)
 277                    }
 278                }
 279            }
 280        })
 281    }
 282
    /// Drives one client connection from registration until sign-out.
    ///
    /// The steps, in order:
    /// 1. Adds `connection` to the peer and reports the new connection id through
    ///    `send_connection_id`, if given.
    /// 2. On the user's first-ever connection, sends `ShowContacts` and records
    ///    `connected_once` in the database.
    /// 3. Registers the connection in the store and sends the initial contacts update and,
    ///    if the user has an invite code, the invite info.
    /// 4. Refreshes the user's entry for their contacts.
    /// 5. Dispatches incoming messages to the registered handlers until the peer's I/O
    ///    future finishes or the incoming stream ends.
    /// 6. Signs the connection out. A sign-out failure is logged, not returned.
    ///
    /// The returned future resolves to `Err` only when one of the setup steps before the
    /// message loop fails.
    ///
    /// NOTE(review): those early `?` exits skip `sign_out`. The peer connection, and the
    /// store entry if `add_connection` already ran, appear to be left behind. Confirm whether
    /// callers clean this up.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The peer is given a factory for timer futures backed by `executor`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                })
                .await;

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");

            // Best-effort: a dropped receiver is ignored.
            if let Some(send_connection_id) = send_connection_id.as_mut() {
                let _ = send_connection_id.send(connection_id).await;
            }

            // First connection ever for this user.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Load contacts and the invite code concurrently.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Register the connection and send its initial state under one store write lock.
            {
                let mut store = this.store_mut().await;
                store.add_connection(connection_id, user_id);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            // Message loop. `select_biased!` polls `handle_io` first on each iteration.
            // Foreground messages are awaited inline, so the next message from this connection
            // is not read until the current one is handled. Background messages run on a
            // detached task. After each message is handled, `notifications` (if set) gets a `()`.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            async {
                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                    let notifications = this.notifications.clone();
                                    let is_background = message.is_background();
                                    let handle_message = (handler)(this.clone(), message);
                                    let handle_message = async move {
                                        handle_message.await;
                                        if let Some(mut notifications) = notifications {
                                            let _ = notifications.send(()).await;
                                        }
                                    };
                                    if is_background {
                                        executor.spawn_detached(handle_message);
                                    } else {
                                        handle_message.await;
                                    }
                                } else {
                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                                }
                            }.instrument(span).await;
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 391
 392    #[instrument(skip(self), err)]
 393    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 394        self.peer.disconnect(connection_id);
 395
 396        let removed_user_id = {
 397            let mut store = self.store_mut().await;
 398            let removed_connection = store.remove_connection(connection_id)?;
 399
 400            for (project_id, project) in removed_connection.hosted_projects {
 401                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 402                    self.peer
 403                        .send(conn_id, proto::UnregisterProject { project_id })
 404                });
 405
 406                for (_, receipts) in project.join_requests {
 407                    for receipt in receipts {
 408                        self.peer.respond(
 409                            receipt,
 410                            proto::JoinProjectResponse {
 411                                variant: Some(proto::join_project_response::Variant::Decline(
 412                                    proto::join_project_response::Decline {
 413                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 414                                    },
 415                                )),
 416                            },
 417                        )?;
 418                    }
 419                }
 420            }
 421
 422            for project_id in removed_connection.guest_project_ids {
 423                if let Some(project) = store.project(project_id).trace_err() {
 424                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 425                        self.peer.send(
 426                            conn_id,
 427                            proto::RemoveProjectCollaborator {
 428                                project_id,
 429                                peer_id: connection_id.0,
 430                            },
 431                        )
 432                    });
 433                    if project.guests.is_empty() {
 434                        self.peer
 435                            .send(
 436                                project.host_connection_id,
 437                                proto::ProjectUnshared { project_id },
 438                            )
 439                            .trace_err();
 440                    }
 441                }
 442            }
 443
 444            removed_connection.user_id
 445        };
 446
 447        self.update_user_contacts(removed_user_id).await?;
 448
 449        Ok(())
 450    }
 451
 452    pub async fn invite_code_redeemed(
 453        self: &Arc<Self>,
 454        code: &str,
 455        invitee_id: UserId,
 456    ) -> Result<()> {
 457        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 458        let store = self.store().await;
 459        let invitee_contact = store.contact_for_user(invitee_id, true);
 460        for connection_id in store.connection_ids_for_user(user.id) {
 461            self.peer.send(
 462                connection_id,
 463                proto::UpdateContacts {
 464                    contacts: vec![invitee_contact.clone()],
 465                    ..Default::default()
 466                },
 467            )?;
 468            self.peer.send(
 469                connection_id,
 470                proto::UpdateInviteInfo {
 471                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 472                    count: user.invite_count as u32,
 473                },
 474            )?;
 475        }
 476        Ok(())
 477    }
 478
 479    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 480        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 481            if let Some(invite_code) = &user.invite_code {
 482                let store = self.store().await;
 483                for connection_id in store.connection_ids_for_user(user_id) {
 484                    self.peer.send(
 485                        connection_id,
 486                        proto::UpdateInviteInfo {
 487                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 488                            count: user.invite_count as u32,
 489                        },
 490                    )?;
 491                }
 492            }
 493        }
 494        Ok(())
 495    }
 496
 497    async fn ping(
 498        self: Arc<Server>,
 499        _: TypedEnvelope<proto::Ping>,
 500        response: Response<proto::Ping>,
 501    ) -> Result<()> {
 502        response.send(proto::Ack {})?;
 503        Ok(())
 504    }
 505
 506    async fn register_project(
 507        self: Arc<Server>,
 508        request: TypedEnvelope<proto::RegisterProject>,
 509        response: Response<proto::RegisterProject>,
 510    ) -> Result<()> {
 511        let project_id;
 512        {
 513            let mut state = self.store_mut().await;
 514            let user_id = state.user_id_for_connection(request.sender_id)?;
 515            project_id = state.register_project(request.sender_id, user_id);
 516        };
 517
 518        response.send(proto::RegisterProjectResponse { project_id })?;
 519
 520        Ok(())
 521    }
 522
 523    async fn unregister_project(
 524        self: Arc<Server>,
 525        request: TypedEnvelope<proto::UnregisterProject>,
 526        response: Response<proto::UnregisterProject>,
 527    ) -> Result<()> {
 528        let (user_id, project) = {
 529            let mut state = self.store_mut().await;
 530            let project =
 531                state.unregister_project(request.payload.project_id, request.sender_id)?;
 532            (state.user_id_for_connection(request.sender_id)?, project)
 533        };
 534
 535        broadcast(
 536            request.sender_id,
 537            project.guests.keys().copied(),
 538            |conn_id| {
 539                self.peer.send(
 540                    conn_id,
 541                    proto::UnregisterProject {
 542                        project_id: request.payload.project_id,
 543                    },
 544                )
 545            },
 546        );
 547        for (_, receipts) in project.join_requests {
 548            for receipt in receipts {
 549                self.peer.respond(
 550                    receipt,
 551                    proto::JoinProjectResponse {
 552                        variant: Some(proto::join_project_response::Variant::Decline(
 553                            proto::join_project_response::Decline {
 554                                reason: proto::join_project_response::decline::Reason::Closed
 555                                    as i32,
 556                            },
 557                        )),
 558                    },
 559                )?;
 560            }
 561        }
 562
 563        // Send out the `UpdateContacts` message before responding to the unregister
 564        // request. This way, when the project's host can keep track of the project's
 565        // remote id until after they've received the `UpdateContacts` message for
 566        // themself.
 567        self.update_user_contacts(user_id).await?;
 568        response.send(proto::Ack {})?;
 569
 570        Ok(())
 571    }
 572
 573    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 574        let contacts = self.app_state.db.get_contacts(user_id).await?;
 575        let store = self.store().await;
 576        let updated_contact = store.contact_for_user(user_id, false);
 577        for contact in contacts {
 578            if let db::Contact::Accepted {
 579                user_id: contact_user_id,
 580                ..
 581            } = contact
 582            {
 583                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 584                    self.peer
 585                        .send(
 586                            contact_conn_id,
 587                            proto::UpdateContacts {
 588                                contacts: vec![updated_contact.clone()],
 589                                remove_contacts: Default::default(),
 590                                incoming_requests: Default::default(),
 591                                remove_incoming_requests: Default::default(),
 592                                outgoing_requests: Default::default(),
 593                                remove_outgoing_requests: Default::default(),
 594                            },
 595                        )
 596                        .trace_err();
 597                }
 598            }
 599        }
 600        Ok(())
 601    }
 602
 603    async fn join_project(
 604        self: Arc<Server>,
 605        request: TypedEnvelope<proto::JoinProject>,
 606        response: Response<proto::JoinProject>,
 607    ) -> Result<()> {
 608        let project_id = request.payload.project_id;
 609
 610        let host_user_id;
 611        let guest_user_id;
 612        let host_connection_id;
 613        {
 614            let state = self.store().await;
 615            let project = state.project(project_id)?;
 616            host_user_id = project.host_user_id;
 617            host_connection_id = project.host_connection_id;
 618            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 619        };
 620
 621        let has_contact = self
 622            .app_state
 623            .db
 624            .has_contact(guest_user_id, host_user_id)
 625            .await?;
 626        if !has_contact {
 627            return Err(anyhow!("no such project"))?;
 628        }
 629
 630        self.store_mut().await.request_join_project(
 631            guest_user_id,
 632            project_id,
 633            response.into_receipt(),
 634        )?;
 635        self.peer.send(
 636            host_connection_id,
 637            proto::RequestJoinProject {
 638                project_id,
 639                requester_id: guest_user_id.to_proto(),
 640            },
 641        )?;
 642        Ok(())
 643    }
 644
 645    async fn respond_to_join_project_request(
 646        self: Arc<Server>,
 647        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 648    ) -> Result<()> {
 649        let host_user_id;
 650
 651        {
 652            let mut state = self.store_mut().await;
 653            let project_id = request.payload.project_id;
 654            let project = state.project(project_id)?;
 655            if project.host_connection_id != request.sender_id {
 656                Err(anyhow!("no such connection"))?;
 657            }
 658
 659            host_user_id = project.host_user_id;
 660            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 661
 662            if !request.payload.allow {
 663                let receipts = state
 664                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 665                    .ok_or_else(|| anyhow!("no such request"))?;
 666                for receipt in receipts {
 667                    self.peer.respond(
 668                        receipt,
 669                        proto::JoinProjectResponse {
 670                            variant: Some(proto::join_project_response::Variant::Decline(
 671                                proto::join_project_response::Decline {
 672                                    reason: proto::join_project_response::decline::Reason::Declined
 673                                        as i32,
 674                                },
 675                            )),
 676                        },
 677                    )?;
 678                }
 679                return Ok(());
 680            }
 681
 682            let (receipts_with_replica_ids, project) = state
 683                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 684                .ok_or_else(|| anyhow!("no such request"))?;
 685
 686            let peer_count = project.guests.len();
 687            let mut collaborators = Vec::with_capacity(peer_count);
 688            collaborators.push(proto::Collaborator {
 689                peer_id: project.host_connection_id.0,
 690                replica_id: 0,
 691                user_id: project.host_user_id.to_proto(),
 692            });
 693            let worktrees = project
 694                .worktrees
 695                .iter()
 696                .filter_map(|(id, shared_worktree)| {
 697                    let worktree = project.worktrees.get(&id)?;
 698                    Some(proto::Worktree {
 699                        id: *id,
 700                        root_name: worktree.root_name.clone(),
 701                        entries: shared_worktree.entries.values().cloned().collect(),
 702                        diagnostic_summaries: shared_worktree
 703                            .diagnostic_summaries
 704                            .values()
 705                            .cloned()
 706                            .collect(),
 707                        visible: worktree.visible,
 708                        scan_id: shared_worktree.scan_id,
 709                    })
 710                })
 711                .collect::<Vec<_>>();
 712
 713            // Add all guests other than the requesting user's own connections as collaborators
 714            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 715                if receipts_with_replica_ids
 716                    .iter()
 717                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 718                {
 719                    collaborators.push(proto::Collaborator {
 720                        peer_id: peer_conn_id.0,
 721                        replica_id: *peer_replica_id as u32,
 722                        user_id: peer_user_id.to_proto(),
 723                    });
 724                }
 725            }
 726
 727            for conn_id in project.connection_ids() {
 728                for (receipt, replica_id) in &receipts_with_replica_ids {
 729                    if conn_id != receipt.sender_id {
 730                        self.peer.send(
 731                            conn_id,
 732                            proto::AddProjectCollaborator {
 733                                project_id,
 734                                collaborator: Some(proto::Collaborator {
 735                                    peer_id: receipt.sender_id.0,
 736                                    replica_id: *replica_id as u32,
 737                                    user_id: guest_user_id.to_proto(),
 738                                }),
 739                            },
 740                        )?;
 741                    }
 742                }
 743            }
 744
 745            for (receipt, replica_id) in receipts_with_replica_ids {
 746                self.peer.respond(
 747                    receipt,
 748                    proto::JoinProjectResponse {
 749                        variant: Some(proto::join_project_response::Variant::Accept(
 750                            proto::join_project_response::Accept {
 751                                worktrees: worktrees.clone(),
 752                                replica_id: replica_id as u32,
 753                                collaborators: collaborators.clone(),
 754                                language_servers: project.language_servers.clone(),
 755                            },
 756                        )),
 757                    },
 758                )?;
 759            }
 760        }
 761
 762        self.update_user_contacts(host_user_id).await?;
 763        Ok(())
 764    }
 765
 766    async fn leave_project(
 767        self: Arc<Server>,
 768        request: TypedEnvelope<proto::LeaveProject>,
 769    ) -> Result<()> {
 770        let sender_id = request.sender_id;
 771        let project_id = request.payload.project_id;
 772        let project;
 773        {
 774            let mut store = self.store_mut().await;
 775            project = store.leave_project(sender_id, project_id)?;
 776
 777            if project.remove_collaborator {
 778                broadcast(sender_id, project.connection_ids, |conn_id| {
 779                    self.peer.send(
 780                        conn_id,
 781                        proto::RemoveProjectCollaborator {
 782                            project_id,
 783                            peer_id: sender_id.0,
 784                        },
 785                    )
 786                });
 787            }
 788
 789            if let Some(requester_id) = project.cancel_request {
 790                self.peer.send(
 791                    project.host_connection_id,
 792                    proto::JoinProjectRequestCancelled {
 793                        project_id,
 794                        requester_id: requester_id.to_proto(),
 795                    },
 796                )?;
 797            }
 798
 799            if project.unshare {
 800                self.peer.send(
 801                    project.host_connection_id,
 802                    proto::ProjectUnshared { project_id },
 803                )?;
 804            }
 805        }
 806        self.update_user_contacts(project.host_user_id).await?;
 807        Ok(())
 808    }
 809
 810    async fn update_project(
 811        self: Arc<Server>,
 812        request: TypedEnvelope<proto::UpdateProject>,
 813    ) -> Result<()> {
 814        let user_id;
 815        {
 816            let mut state = self.store_mut().await;
 817            user_id = state.user_id_for_connection(request.sender_id)?;
 818            let guest_connection_ids = state
 819                .read_project(request.payload.project_id, request.sender_id)?
 820                .guest_connection_ids();
 821            state.update_project(
 822                request.payload.project_id,
 823                &request.payload.worktrees,
 824                request.sender_id,
 825            )?;
 826            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 827                self.peer
 828                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 829            });
 830        };
 831        self.update_user_contacts(user_id).await?;
 832        Ok(())
 833    }
 834
 835    async fn update_worktree(
 836        self: Arc<Server>,
 837        request: TypedEnvelope<proto::UpdateWorktree>,
 838        response: Response<proto::UpdateWorktree>,
 839    ) -> Result<()> {
 840        let (connection_ids, metadata_changed) = self.store_mut().await.update_worktree(
 841            request.sender_id,
 842            request.payload.project_id,
 843            request.payload.worktree_id,
 844            &request.payload.root_name,
 845            &request.payload.removed_entries,
 846            &request.payload.updated_entries,
 847            request.payload.scan_id,
 848        )?;
 849
 850        broadcast(request.sender_id, connection_ids, |connection_id| {
 851            self.peer
 852                .forward_send(request.sender_id, connection_id, request.payload.clone())
 853        });
 854        if metadata_changed {
 855            let user_id = self
 856                .store()
 857                .await
 858                .user_id_for_connection(request.sender_id)?;
 859            self.update_user_contacts(user_id).await?;
 860        }
 861        response.send(proto::Ack {})?;
 862        Ok(())
 863    }
 864
 865    async fn update_diagnostic_summary(
 866        self: Arc<Server>,
 867        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 868    ) -> Result<()> {
 869        let summary = request
 870            .payload
 871            .summary
 872            .clone()
 873            .ok_or_else(|| anyhow!("invalid summary"))?;
 874        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 875            request.payload.project_id,
 876            request.payload.worktree_id,
 877            request.sender_id,
 878            summary,
 879        )?;
 880
 881        broadcast(request.sender_id, receiver_ids, |connection_id| {
 882            self.peer
 883                .forward_send(request.sender_id, connection_id, request.payload.clone())
 884        });
 885        Ok(())
 886    }
 887
 888    async fn start_language_server(
 889        self: Arc<Server>,
 890        request: TypedEnvelope<proto::StartLanguageServer>,
 891    ) -> Result<()> {
 892        let receiver_ids = self.store_mut().await.start_language_server(
 893            request.payload.project_id,
 894            request.sender_id,
 895            request
 896                .payload
 897                .server
 898                .clone()
 899                .ok_or_else(|| anyhow!("invalid language server"))?,
 900        )?;
 901        broadcast(request.sender_id, receiver_ids, |connection_id| {
 902            self.peer
 903                .forward_send(request.sender_id, connection_id, request.payload.clone())
 904        });
 905        Ok(())
 906    }
 907
 908    async fn update_language_server(
 909        self: Arc<Server>,
 910        request: TypedEnvelope<proto::UpdateLanguageServer>,
 911    ) -> Result<()> {
 912        let receiver_ids = self
 913            .store()
 914            .await
 915            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 916        broadcast(request.sender_id, receiver_ids, |connection_id| {
 917            self.peer
 918                .forward_send(request.sender_id, connection_id, request.payload.clone())
 919        });
 920        Ok(())
 921    }
 922
 923    async fn forward_project_request<T>(
 924        self: Arc<Server>,
 925        request: TypedEnvelope<T>,
 926        response: Response<T>,
 927    ) -> Result<()>
 928    where
 929        T: EntityMessage + RequestMessage,
 930    {
 931        let host_connection_id = self
 932            .store()
 933            .await
 934            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 935            .host_connection_id;
 936
 937        response.send(
 938            self.peer
 939                .forward_request(request.sender_id, host_connection_id, request.payload)
 940                .await?,
 941        )?;
 942        Ok(())
 943    }
 944
 945    async fn save_buffer(
 946        self: Arc<Server>,
 947        request: TypedEnvelope<proto::SaveBuffer>,
 948        response: Response<proto::SaveBuffer>,
 949    ) -> Result<()> {
 950        let host = self
 951            .store()
 952            .await
 953            .read_project(request.payload.project_id, request.sender_id)?
 954            .host_connection_id;
 955        let response_payload = self
 956            .peer
 957            .forward_request(request.sender_id, host, request.payload.clone())
 958            .await?;
 959
 960        let mut guests = self
 961            .store()
 962            .await
 963            .read_project(request.payload.project_id, request.sender_id)?
 964            .connection_ids();
 965        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
 966        broadcast(host, guests, |conn_id| {
 967            self.peer
 968                .forward_send(host, conn_id, response_payload.clone())
 969        });
 970        response.send(response_payload)?;
 971        Ok(())
 972    }
 973
 974    async fn update_buffer(
 975        self: Arc<Server>,
 976        request: TypedEnvelope<proto::UpdateBuffer>,
 977        response: Response<proto::UpdateBuffer>,
 978    ) -> Result<()> {
 979        let receiver_ids = self
 980            .store()
 981            .await
 982            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 983        broadcast(request.sender_id, receiver_ids, |connection_id| {
 984            self.peer
 985                .forward_send(request.sender_id, connection_id, request.payload.clone())
 986        });
 987        response.send(proto::Ack {})?;
 988        Ok(())
 989    }
 990
 991    async fn update_buffer_file(
 992        self: Arc<Server>,
 993        request: TypedEnvelope<proto::UpdateBufferFile>,
 994    ) -> Result<()> {
 995        let receiver_ids = self
 996            .store()
 997            .await
 998            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 999        broadcast(request.sender_id, receiver_ids, |connection_id| {
1000            self.peer
1001                .forward_send(request.sender_id, connection_id, request.payload.clone())
1002        });
1003        Ok(())
1004    }
1005
1006    async fn buffer_reloaded(
1007        self: Arc<Server>,
1008        request: TypedEnvelope<proto::BufferReloaded>,
1009    ) -> Result<()> {
1010        let receiver_ids = self
1011            .store()
1012            .await
1013            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1014        broadcast(request.sender_id, receiver_ids, |connection_id| {
1015            self.peer
1016                .forward_send(request.sender_id, connection_id, request.payload.clone())
1017        });
1018        Ok(())
1019    }
1020
1021    async fn buffer_saved(
1022        self: Arc<Server>,
1023        request: TypedEnvelope<proto::BufferSaved>,
1024    ) -> Result<()> {
1025        let receiver_ids = self
1026            .store()
1027            .await
1028            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1029        broadcast(request.sender_id, receiver_ids, |connection_id| {
1030            self.peer
1031                .forward_send(request.sender_id, connection_id, request.payload.clone())
1032        });
1033        Ok(())
1034    }
1035
1036    async fn follow(
1037        self: Arc<Self>,
1038        request: TypedEnvelope<proto::Follow>,
1039        response: Response<proto::Follow>,
1040    ) -> Result<()> {
1041        let leader_id = ConnectionId(request.payload.leader_id);
1042        let follower_id = request.sender_id;
1043        if !self
1044            .store()
1045            .await
1046            .project_connection_ids(request.payload.project_id, follower_id)?
1047            .contains(&leader_id)
1048        {
1049            Err(anyhow!("no such peer"))?;
1050        }
1051        let mut response_payload = self
1052            .peer
1053            .forward_request(request.sender_id, leader_id, request.payload)
1054            .await?;
1055        response_payload
1056            .views
1057            .retain(|view| view.leader_id != Some(follower_id.0));
1058        response.send(response_payload)?;
1059        Ok(())
1060    }
1061
1062    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1063        let leader_id = ConnectionId(request.payload.leader_id);
1064        if !self
1065            .store()
1066            .await
1067            .project_connection_ids(request.payload.project_id, request.sender_id)?
1068            .contains(&leader_id)
1069        {
1070            Err(anyhow!("no such peer"))?;
1071        }
1072        self.peer
1073            .forward_send(request.sender_id, leader_id, request.payload)?;
1074        Ok(())
1075    }
1076
1077    async fn update_followers(
1078        self: Arc<Self>,
1079        request: TypedEnvelope<proto::UpdateFollowers>,
1080    ) -> Result<()> {
1081        let connection_ids = self
1082            .store()
1083            .await
1084            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1085        let leader_id = request
1086            .payload
1087            .variant
1088            .as_ref()
1089            .and_then(|variant| match variant {
1090                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1091                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1092                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1093            });
1094        for follower_id in &request.payload.follower_ids {
1095            let follower_id = ConnectionId(*follower_id);
1096            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1097                self.peer
1098                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1099            }
1100        }
1101        Ok(())
1102    }
1103
1104    async fn get_channels(
1105        self: Arc<Server>,
1106        request: TypedEnvelope<proto::GetChannels>,
1107        response: Response<proto::GetChannels>,
1108    ) -> Result<()> {
1109        let user_id = self
1110            .store()
1111            .await
1112            .user_id_for_connection(request.sender_id)?;
1113        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1114        response.send(proto::GetChannelsResponse {
1115            channels: channels
1116                .into_iter()
1117                .map(|chan| proto::Channel {
1118                    id: chan.id.to_proto(),
1119                    name: chan.name,
1120                })
1121                .collect(),
1122        })?;
1123        Ok(())
1124    }
1125
1126    async fn get_users(
1127        self: Arc<Server>,
1128        request: TypedEnvelope<proto::GetUsers>,
1129        response: Response<proto::GetUsers>,
1130    ) -> Result<()> {
1131        let user_ids = request
1132            .payload
1133            .user_ids
1134            .into_iter()
1135            .map(UserId::from_proto)
1136            .collect();
1137        let users = self
1138            .app_state
1139            .db
1140            .get_users_by_ids(user_ids)
1141            .await?
1142            .into_iter()
1143            .map(|user| proto::User {
1144                id: user.id.to_proto(),
1145                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1146                github_login: user.github_login,
1147            })
1148            .collect();
1149        response.send(proto::UsersResponse { users })?;
1150        Ok(())
1151    }
1152
1153    async fn fuzzy_search_users(
1154        self: Arc<Server>,
1155        request: TypedEnvelope<proto::FuzzySearchUsers>,
1156        response: Response<proto::FuzzySearchUsers>,
1157    ) -> Result<()> {
1158        let user_id = self
1159            .store()
1160            .await
1161            .user_id_for_connection(request.sender_id)?;
1162        let query = request.payload.query;
1163        let db = &self.app_state.db;
1164        let users = match query.len() {
1165            0 => vec![],
1166            1 | 2 => db
1167                .get_user_by_github_login(&query)
1168                .await?
1169                .into_iter()
1170                .collect(),
1171            _ => db.fuzzy_search_users(&query, 10).await?,
1172        };
1173        let users = users
1174            .into_iter()
1175            .filter(|user| user.id != user_id)
1176            .map(|user| proto::User {
1177                id: user.id.to_proto(),
1178                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1179                github_login: user.github_login,
1180            })
1181            .collect();
1182        response.send(proto::UsersResponse { users })?;
1183        Ok(())
1184    }
1185
1186    async fn request_contact(
1187        self: Arc<Server>,
1188        request: TypedEnvelope<proto::RequestContact>,
1189        response: Response<proto::RequestContact>,
1190    ) -> Result<()> {
1191        let requester_id = self
1192            .store()
1193            .await
1194            .user_id_for_connection(request.sender_id)?;
1195        let responder_id = UserId::from_proto(request.payload.responder_id);
1196        if requester_id == responder_id {
1197            return Err(anyhow!("cannot add yourself as a contact"))?;
1198        }
1199
1200        self.app_state
1201            .db
1202            .send_contact_request(requester_id, responder_id)
1203            .await?;
1204
1205        // Update outgoing contact requests of requester
1206        let mut update = proto::UpdateContacts::default();
1207        update.outgoing_requests.push(responder_id.to_proto());
1208        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1209            self.peer.send(connection_id, update.clone())?;
1210        }
1211
1212        // Update incoming contact requests of responder
1213        let mut update = proto::UpdateContacts::default();
1214        update
1215            .incoming_requests
1216            .push(proto::IncomingContactRequest {
1217                requester_id: requester_id.to_proto(),
1218                should_notify: true,
1219            });
1220        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1221            self.peer.send(connection_id, update.clone())?;
1222        }
1223
1224        response.send(proto::Ack {})?;
1225        Ok(())
1226    }
1227
1228    async fn respond_to_contact_request(
1229        self: Arc<Server>,
1230        request: TypedEnvelope<proto::RespondToContactRequest>,
1231        response: Response<proto::RespondToContactRequest>,
1232    ) -> Result<()> {
1233        let responder_id = self
1234            .store()
1235            .await
1236            .user_id_for_connection(request.sender_id)?;
1237        let requester_id = UserId::from_proto(request.payload.requester_id);
1238        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1239            self.app_state
1240                .db
1241                .dismiss_contact_notification(responder_id, requester_id)
1242                .await?;
1243        } else {
1244            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1245            self.app_state
1246                .db
1247                .respond_to_contact_request(responder_id, requester_id, accept)
1248                .await?;
1249
1250            let store = self.store().await;
1251            // Update responder with new contact
1252            let mut update = proto::UpdateContacts::default();
1253            if accept {
1254                update
1255                    .contacts
1256                    .push(store.contact_for_user(requester_id, false));
1257            }
1258            update
1259                .remove_incoming_requests
1260                .push(requester_id.to_proto());
1261            for connection_id in store.connection_ids_for_user(responder_id) {
1262                self.peer.send(connection_id, update.clone())?;
1263            }
1264
1265            // Update requester with new contact
1266            let mut update = proto::UpdateContacts::default();
1267            if accept {
1268                update
1269                    .contacts
1270                    .push(store.contact_for_user(responder_id, true));
1271            }
1272            update
1273                .remove_outgoing_requests
1274                .push(responder_id.to_proto());
1275            for connection_id in store.connection_ids_for_user(requester_id) {
1276                self.peer.send(connection_id, update.clone())?;
1277            }
1278        }
1279
1280        response.send(proto::Ack {})?;
1281        Ok(())
1282    }
1283
1284    async fn remove_contact(
1285        self: Arc<Server>,
1286        request: TypedEnvelope<proto::RemoveContact>,
1287        response: Response<proto::RemoveContact>,
1288    ) -> Result<()> {
1289        let requester_id = self
1290            .store()
1291            .await
1292            .user_id_for_connection(request.sender_id)?;
1293        let responder_id = UserId::from_proto(request.payload.user_id);
1294        self.app_state
1295            .db
1296            .remove_contact(requester_id, responder_id)
1297            .await?;
1298
1299        // Update outgoing contact requests of requester
1300        let mut update = proto::UpdateContacts::default();
1301        update
1302            .remove_outgoing_requests
1303            .push(responder_id.to_proto());
1304        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1305            self.peer.send(connection_id, update.clone())?;
1306        }
1307
1308        // Update incoming contact requests of responder
1309        let mut update = proto::UpdateContacts::default();
1310        update
1311            .remove_incoming_requests
1312            .push(requester_id.to_proto());
1313        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1314            self.peer.send(connection_id, update.clone())?;
1315        }
1316
1317        response.send(proto::Ack {})?;
1318        Ok(())
1319    }
1320
1321    async fn join_channel(
1322        self: Arc<Self>,
1323        request: TypedEnvelope<proto::JoinChannel>,
1324        response: Response<proto::JoinChannel>,
1325    ) -> Result<()> {
1326        let user_id = self
1327            .store()
1328            .await
1329            .user_id_for_connection(request.sender_id)?;
1330        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1331        if !self
1332            .app_state
1333            .db
1334            .can_user_access_channel(user_id, channel_id)
1335            .await?
1336        {
1337            Err(anyhow!("access denied"))?;
1338        }
1339
1340        self.store_mut()
1341            .await
1342            .join_channel(request.sender_id, channel_id);
1343        let messages = self
1344            .app_state
1345            .db
1346            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1347            .await?
1348            .into_iter()
1349            .map(|msg| proto::ChannelMessage {
1350                id: msg.id.to_proto(),
1351                body: msg.body,
1352                timestamp: msg.sent_at.unix_timestamp() as u64,
1353                sender_id: msg.sender_id.to_proto(),
1354                nonce: Some(msg.nonce.as_u128().into()),
1355            })
1356            .collect::<Vec<_>>();
1357        response.send(proto::JoinChannelResponse {
1358            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1359            messages,
1360        })?;
1361        Ok(())
1362    }
1363
1364    async fn leave_channel(
1365        self: Arc<Self>,
1366        request: TypedEnvelope<proto::LeaveChannel>,
1367    ) -> Result<()> {
1368        let user_id = self
1369            .store()
1370            .await
1371            .user_id_for_connection(request.sender_id)?;
1372        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1373        if !self
1374            .app_state
1375            .db
1376            .can_user_access_channel(user_id, channel_id)
1377            .await?
1378        {
1379            Err(anyhow!("access denied"))?;
1380        }
1381
1382        self.store_mut()
1383            .await
1384            .leave_channel(request.sender_id, channel_id);
1385
1386        Ok(())
1387    }
1388
1389    async fn send_channel_message(
1390        self: Arc<Self>,
1391        request: TypedEnvelope<proto::SendChannelMessage>,
1392        response: Response<proto::SendChannelMessage>,
1393    ) -> Result<()> {
1394        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1395        let user_id;
1396        let connection_ids;
1397        {
1398            let state = self.store().await;
1399            user_id = state.user_id_for_connection(request.sender_id)?;
1400            connection_ids = state.channel_connection_ids(channel_id)?;
1401        }
1402
1403        // Validate the message body.
1404        let body = request.payload.body.trim().to_string();
1405        if body.len() > MAX_MESSAGE_LEN {
1406            return Err(anyhow!("message is too long"))?;
1407        }
1408        if body.is_empty() {
1409            return Err(anyhow!("message can't be blank"))?;
1410        }
1411
1412        let timestamp = OffsetDateTime::now_utc();
1413        let nonce = request
1414            .payload
1415            .nonce
1416            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1417
1418        let message_id = self
1419            .app_state
1420            .db
1421            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1422            .await?
1423            .to_proto();
1424        let message = proto::ChannelMessage {
1425            sender_id: user_id.to_proto(),
1426            id: message_id,
1427            body,
1428            timestamp: timestamp.unix_timestamp() as u64,
1429            nonce: Some(nonce),
1430        };
1431        broadcast(request.sender_id, connection_ids, |conn_id| {
1432            self.peer.send(
1433                conn_id,
1434                proto::ChannelMessageSent {
1435                    channel_id: channel_id.to_proto(),
1436                    message: Some(message.clone()),
1437                },
1438            )
1439        });
1440        response.send(proto::SendChannelMessageResponse {
1441            message: Some(message),
1442        })?;
1443        Ok(())
1444    }
1445
1446    async fn get_channel_messages(
1447        self: Arc<Self>,
1448        request: TypedEnvelope<proto::GetChannelMessages>,
1449        response: Response<proto::GetChannelMessages>,
1450    ) -> Result<()> {
1451        let user_id = self
1452            .store()
1453            .await
1454            .user_id_for_connection(request.sender_id)?;
1455        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1456        if !self
1457            .app_state
1458            .db
1459            .can_user_access_channel(user_id, channel_id)
1460            .await?
1461        {
1462            Err(anyhow!("access denied"))?;
1463        }
1464
1465        let messages = self
1466            .app_state
1467            .db
1468            .get_channel_messages(
1469                channel_id,
1470                MESSAGE_COUNT_PER_PAGE,
1471                Some(MessageId::from_proto(request.payload.before_message_id)),
1472            )
1473            .await?
1474            .into_iter()
1475            .map(|msg| proto::ChannelMessage {
1476                id: msg.id.to_proto(),
1477                body: msg.body,
1478                timestamp: msg.sent_at.unix_timestamp() as u64,
1479                sender_id: msg.sender_id.to_proto(),
1480                nonce: Some(msg.nonce.as_u128().into()),
1481            })
1482            .collect::<Vec<_>>();
1483        response.send(proto::GetChannelMessagesResponse {
1484            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1485            messages,
1486        })?;
1487        Ok(())
1488    }
1489
1490    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1491        #[cfg(test)]
1492        tokio::task::yield_now().await;
1493        let guard = self.store.read().await;
1494        #[cfg(test)]
1495        tokio::task::yield_now().await;
1496        StoreReadGuard {
1497            guard,
1498            _not_send: PhantomData,
1499        }
1500    }
1501
1502    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1503        #[cfg(test)]
1504        tokio::task::yield_now().await;
1505        let guard = self.store.write().await;
1506        #[cfg(test)]
1507        tokio::task::yield_now().await;
1508        StoreWriteGuard {
1509            guard,
1510            _not_send: PhantomData,
1511        }
1512    }
1513
1514    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1515        ServerSnapshot {
1516            store: self.store.read().await,
1517            peer: &self.peer,
1518        }
1519    }
1520}
1521
1522impl<'a> Deref for StoreReadGuard<'a> {
1523    type Target = Store;
1524
1525    fn deref(&self) -> &Self::Target {
1526        &*self.guard
1527    }
1528}
1529
1530impl<'a> Deref for StoreWriteGuard<'a> {
1531    type Target = Store;
1532
1533    fn deref(&self) -> &Self::Target {
1534        &*self.guard
1535    }
1536}
1537
1538impl<'a> DerefMut for StoreWriteGuard<'a> {
1539    fn deref_mut(&mut self) -> &mut Self::Target {
1540        &mut *self.guard
1541    }
1542}
1543
1544impl<'a> Drop for StoreWriteGuard<'a> {
1545    fn drop(&mut self) {
1546        #[cfg(test)]
1547        self.check_invariants();
1548
1549        let metrics = self.metrics();
1550
1551        METRIC_CONNECTIONS.set(metrics.connections as _);
1552        METRIC_PROJECTS.set(metrics.registered_projects as _);
1553        METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1554
1555        tracing::info!(
1556            connections = metrics.connections,
1557            registered_projects = metrics.registered_projects,
1558            shared_projects = metrics.shared_projects,
1559            "metrics"
1560        );
1561    }
1562}
1563
1564impl Executor for RealExecutor {
1565    type Sleep = Sleep;
1566
1567    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1568        tokio::task::spawn(future);
1569    }
1570
1571    fn sleep(&self, duration: Duration) -> Self::Sleep {
1572        tokio::time::sleep(duration)
1573    }
1574}
1575
1576fn broadcast<F>(
1577    sender_id: ConnectionId,
1578    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1579    mut f: F,
1580) where
1581    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1582{
1583    for receiver_id in receiver_ids {
1584        if receiver_id != sender_id {
1585            f(receiver_id).trace_err();
1586        }
1587    }
1588}
1589
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol
    /// version; returned by `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1593
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client
/// speaks. `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
1595
1596impl Header for ProtocolVersion {
1597    fn name() -> &'static HeaderName {
1598        &ZED_PROTOCOL_VERSION
1599    }
1600
1601    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1602    where
1603        Self: Sized,
1604        I: Iterator<Item = &'i axum::http::HeaderValue>,
1605    {
1606        let version = values
1607            .next()
1608            .ok_or_else(|| axum::headers::Error::invalid())?
1609            .to_str()
1610            .map_err(|_| axum::headers::Error::invalid())?
1611            .parse()
1612            .map_err(|_| axum::headers::Error::invalid())?;
1613        Ok(Self(version))
1614    }
1615
1616    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1617        values.extend([self.0.to_string().parse().unwrap()]);
1618    }
1619}
1620
1621pub fn routes(server: Arc<Server>) -> Router<Body> {
1622    Router::new()
1623        .route("/rpc", get(handle_websocket_request))
1624        .layer(
1625            ServiceBuilder::new()
1626                .layer(Extension(server.app_state.clone()))
1627                .layer(middleware::from_fn(auth::validate_header))
1628                .layer(Extension(server)),
1629        )
1630        .route("/metrics", get(handle_metrics))
1631}
1632
/// Axum handler for `GET /rpc`. Upgrades the request to a websocket and hands the
/// resulting connection to `Server::handle_connection`.
///
/// The client's `ProtocolVersion` header is compared with `rpc::PROTOCOL_VERSION`.
/// On mismatch the handler replies `426 Upgrade Required` with the body
/// "client must be upgraded" instead of upgrading.
///
/// NOTE(review): the `User` extension is presumably inserted by the
/// `auth::validate_header` middleware configured in `routes` — confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Refuse clients built against a different RPC protocol before upgrading.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    // Only the textual peer address is passed on; it is moved into the upgrade closure.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Imported for the `log_err` call below (the local `ResultExt` only has `trace_err`).
        use util::ResultExt;
        // Adapt axum's websocket to the tungstenite message types `Connection` works with:
        // - incoming axum messages are mapped with `to_tungstenite_message`;
        // - errors are converted with `err_into`;
        // - outgoing tungstenite messages are mapped back with `to_axum_message` before sending.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // NOTE(review): the meaning of the `None` argument is defined by
            // `Server::handle_connection`, which is not visible here.
            // Errors from the connection are logged, not propagated; the upgrade task yields `()`.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1663
1664pub async fn handle_metrics() -> axum::response::Response {
1665    let encoder = prometheus::TextEncoder::new();
1666    let metric_families = prometheus::gather();
1667    match encoder.encode_to_string(&metric_families) {
1668        Ok(string) => (StatusCode::OK, string).into_response(),
1669        Err(error) => (
1670            StatusCode::INTERNAL_SERVER_ERROR,
1671            format!("failed to encode metrics {:?}", error),
1672        )
1673            .into_response(),
1674    }
1675}
1676
1677fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1678    match message {
1679        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1680        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1681        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1682        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1683        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1684            code: frame.code.into(),
1685            reason: frame.reason,
1686        })),
1687    }
1688}
1689
1690fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1691    match message {
1692        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1693        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1694        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1695        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1696        AxumMessage::Close(frame) => {
1697            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1698                code: frame.code.into(),
1699                reason: frame.reason,
1700            }))
1701        }
1702    }
1703}
1704
1705pub trait ResultExt {
1706    type Ok;
1707
1708    fn trace_err(self) -> Option<Self::Ok>;
1709}
1710
1711impl<T, E> ResultExt for Result<T, E>
1712where
1713    E: std::fmt::Debug,
1714{
1715    type Ok = T;
1716
1717    fn trace_err(self) -> Option<T> {
1718        match self {
1719            Ok(value) => Some(value),
1720            Err(error) => {
1721                tracing::error!("{:?}", error);
1722                None
1723            }
1724        }
1725    }
1726}