rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use prometheus::{register_int_gauge, IntGauge};
  33use rpc::{
  34    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  35    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  36};
  37use serde::{Serialize, Serializer};
  38use std::{
  39    any::TypeId,
  40    future::Future,
  41    marker::PhantomData,
  42    net::SocketAddr,
  43    ops::{Deref, DerefMut},
  44    rc::Rc,
  45    sync::{
  46        atomic::{AtomicBool, Ordering::SeqCst},
  47        Arc,
  48    },
  49    time::Duration,
  50};
  51use time::OffsetDateTime;
  52use tokio::{
  53    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  54    time::Sleep,
  55};
  56use tower::ServiceBuilder;
  57use tracing::{info_span, instrument, Instrument};
  58
  59pub use store::{Store, Worktree};
  60
  61lazy_static! {
  62    static ref METRIC_CONNECTIONS: IntGauge =
  63        register_int_gauge!("connections", "number of connections").unwrap();
  64    static ref METRIC_PROJECTS: IntGauge =
  65        register_int_gauge!("projects", "number of open projects").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    server: Arc<Server>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.server.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88
  89    fn into_receipt(self) -> Receipt<R> {
  90        self.responded.store(true, SeqCst);
  91        self.receipt
  92    }
  93}
  94
  95pub struct Server {
  96    peer: Arc<Peer>,
  97    pub(crate) store: RwLock<Store>,
  98    app_state: Arc<AppState>,
  99    handlers: HashMap<TypeId, MessageHandler>,
 100    notifications: Option<mpsc::UnboundedSender<()>>,
 101}
 102
 103pub trait Executor: Send + Clone {
 104    type Sleep: Send + Future;
 105    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 106    fn sleep(&self, duration: Duration) -> Self::Sleep;
 107}
 108
 109#[derive(Clone)]
 110pub struct RealExecutor;
 111
 112const MESSAGE_COUNT_PER_PAGE: usize = 100;
 113const MAX_MESSAGE_LEN: usize = 1024;
 114
 115struct StoreReadGuard<'a> {
 116    guard: RwLockReadGuard<'a, Store>,
 117    _not_send: PhantomData<Rc<()>>,
 118}
 119
 120struct StoreWriteGuard<'a> {
 121    guard: RwLockWriteGuard<'a, Store>,
 122    _not_send: PhantomData<Rc<()>>,
 123}
 124
 125#[derive(Serialize)]
 126pub struct ServerSnapshot<'a> {
 127    peer: &'a Peer,
 128    #[serde(serialize_with = "serialize_deref")]
 129    store: RwLockReadGuard<'a, Store>,
 130}
 131
 132pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 133where
 134    S: Serializer,
 135    T: Deref<Target = U>,
 136    U: Serialize,
 137{
 138    Serialize::serialize(value.deref(), serializer)
 139}
 140
 141impl Server {
 142    pub fn new(
 143        app_state: Arc<AppState>,
 144        notifications: Option<mpsc::UnboundedSender<()>>,
 145    ) -> Arc<Self> {
 146        let mut server = Self {
 147            peer: Peer::new(),
 148            app_state,
 149            store: Default::default(),
 150            handlers: Default::default(),
 151            notifications,
 152        };
 153
 154        server
 155            .add_request_handler(Server::ping)
 156            .add_request_handler(Server::register_project)
 157            .add_request_handler(Server::unregister_project)
 158            .add_request_handler(Server::join_project)
 159            .add_message_handler(Server::leave_project)
 160            .add_message_handler(Server::respond_to_join_project_request)
 161            .add_message_handler(Server::update_project)
 162            .add_request_handler(Server::update_worktree)
 163            .add_message_handler(Server::start_language_server)
 164            .add_message_handler(Server::update_language_server)
 165            .add_message_handler(Server::update_diagnostic_summary)
 166            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 167            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 168            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 169            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 170            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 171            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 172            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 173            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 174            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 175            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 176            .add_request_handler(
 177                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 178            )
 179            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 180            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 181            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 182            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 183            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 184            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 185            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 186            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 187            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 188            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 189            .add_request_handler(Server::update_buffer)
 190            .add_message_handler(Server::update_buffer_file)
 191            .add_message_handler(Server::buffer_reloaded)
 192            .add_message_handler(Server::buffer_saved)
 193            .add_request_handler(Server::save_buffer)
 194            .add_request_handler(Server::get_channels)
 195            .add_request_handler(Server::get_users)
 196            .add_request_handler(Server::fuzzy_search_users)
 197            .add_request_handler(Server::request_contact)
 198            .add_request_handler(Server::remove_contact)
 199            .add_request_handler(Server::respond_to_contact_request)
 200            .add_request_handler(Server::join_channel)
 201            .add_message_handler(Server::leave_channel)
 202            .add_request_handler(Server::send_channel_message)
 203            .add_request_handler(Server::follow)
 204            .add_message_handler(Server::unfollow)
 205            .add_message_handler(Server::update_followers)
 206            .add_request_handler(Server::get_channel_messages);
 207
 208        Arc::new(server)
 209    }
 210
 211    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 212    where
 213        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 214        Fut: 'static + Send + Future<Output = Result<()>>,
 215        M: EnvelopedMessage,
 216    {
 217        let prev_handler = self.handlers.insert(
 218            TypeId::of::<M>(),
 219            Box::new(move |server, envelope| {
 220                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 221                let span = info_span!(
 222                    "handle message",
 223                    payload_type = envelope.payload_type_name()
 224                );
 225                span.in_scope(|| {
 226                    tracing::info!(
 227                        payload = format!("{:?}", envelope.payload).as_str(),
 228                        "message payload"
 229                    );
 230                });
 231                let future = (handler)(server, *envelope);
 232                async move {
 233                    if let Err(error) = future.await {
 234                        tracing::error!(%error, "error handling message");
 235                    }
 236                }
 237                .instrument(span)
 238                .boxed()
 239            }),
 240        );
 241        if prev_handler.is_some() {
 242            panic!("registered a handler for the same message twice");
 243        }
 244        self
 245    }
 246
 247    /// Handle a request while holding a lock to the store. This is useful when we're registering
 248    /// a connection but we want to respond on the connection before anybody else can send on it.
 249    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 250    where
 251        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 252        Fut: Send + Future<Output = Result<()>>,
 253        M: RequestMessage,
 254    {
 255        let handler = Arc::new(handler);
 256        self.add_message_handler(move |server, envelope| {
 257            let receipt = envelope.receipt();
 258            let handler = handler.clone();
 259            async move {
 260                let responded = Arc::new(AtomicBool::default());
 261                let response = Response {
 262                    server: server.clone(),
 263                    responded: responded.clone(),
 264                    receipt: envelope.receipt(),
 265                };
 266                match (handler)(server.clone(), envelope, response).await {
 267                    Ok(()) => {
 268                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 269                            Ok(())
 270                        } else {
 271                            Err(anyhow!("handler did not send a response"))?
 272                        }
 273                    }
 274                    Err(error) => {
 275                        server.peer.respond_with_error(
 276                            receipt,
 277                            proto::Error {
 278                                message: error.to_string(),
 279                            },
 280                        )?;
 281                        Err(error)
 282                    }
 283                }
 284            }
 285        })
 286    }
 287
 288    pub fn handle_connection<E: Executor>(
 289        self: &Arc<Self>,
 290        connection: Connection,
 291        address: String,
 292        user: User,
 293        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 294        executor: E,
 295    ) -> impl Future<Output = Result<()>> {
 296        let mut this = self.clone();
 297        let user_id = user.id;
 298        let login = user.github_login;
 299        let span = info_span!("handle connection", %user_id, %login, %address);
 300        async move {
 301            let (connection_id, handle_io, mut incoming_rx) = this
 302                .peer
 303                .add_connection(connection, {
 304                    let executor = executor.clone();
 305                    move |duration| {
 306                        let timer = executor.sleep(duration);
 307                        async move {
 308                            timer.await;
 309                        }
 310                    }
 311                })
 312                .await;
 313
 314            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 315
 316            if let Some(send_connection_id) = send_connection_id.as_mut() {
 317                let _ = send_connection_id.send(connection_id).await;
 318            }
 319
 320            if !user.connected_once {
 321                this.peer.send(connection_id, proto::ShowContacts {})?;
 322                this.app_state.db.set_user_connected_once(user_id, true).await?;
 323            }
 324
 325            let (contacts, invite_code) = future::try_join(
 326                this.app_state.db.get_contacts(user_id),
 327                this.app_state.db.get_invite_code_for_user(user_id)
 328            ).await?;
 329
 330            {
 331                let mut store = this.store_mut().await;
 332                store.add_connection(connection_id, user_id);
 333                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 334
 335                if let Some((code, count)) = invite_code {
 336                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 337                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 338                        count,
 339                    })?;
 340                }
 341            }
 342            this.update_user_contacts(user_id).await?;
 343
 344            let handle_io = handle_io.fuse();
 345            futures::pin_mut!(handle_io);
 346            loop {
 347                let next_message = incoming_rx.next().fuse();
 348                futures::pin_mut!(next_message);
 349                futures::select_biased! {
 350                    result = handle_io => {
 351                        if let Err(error) = result {
 352                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 353                        }
 354                        break;
 355                    }
 356                    message = next_message => {
 357                        if let Some(message) = message {
 358                            let type_name = message.payload_type_name();
 359                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 360                            async {
 361                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 362                                    let notifications = this.notifications.clone();
 363                                    let is_background = message.is_background();
 364                                    let handle_message = (handler)(this.clone(), message);
 365                                    let handle_message = async move {
 366                                        handle_message.await;
 367                                        if let Some(mut notifications) = notifications {
 368                                            let _ = notifications.send(()).await;
 369                                        }
 370                                    };
 371                                    if is_background {
 372                                        executor.spawn_detached(handle_message);
 373                                    } else {
 374                                        handle_message.await;
 375                                    }
 376                                } else {
 377                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 378                                }
 379                            }.instrument(span).await;
 380                        } else {
 381                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 382                            break;
 383                        }
 384                    }
 385                }
 386            }
 387
 388            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 389            if let Err(error) = this.sign_out(connection_id).await {
 390                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 391            }
 392
 393            Ok(())
 394        }.instrument(span)
 395    }
 396
 397    #[instrument(skip(self), err)]
 398    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 399        self.peer.disconnect(connection_id);
 400
 401        let removed_user_id = {
 402            let mut store = self.store_mut().await;
 403            let removed_connection = store.remove_connection(connection_id)?;
 404
 405            for (project_id, project) in removed_connection.hosted_projects {
 406                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 407                    self.peer
 408                        .send(conn_id, proto::UnregisterProject { project_id })
 409                });
 410
 411                for (_, receipts) in project.join_requests {
 412                    for receipt in receipts {
 413                        self.peer.respond(
 414                            receipt,
 415                            proto::JoinProjectResponse {
 416                                variant: Some(proto::join_project_response::Variant::Decline(
 417                                    proto::join_project_response::Decline {
 418                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 419                                    },
 420                                )),
 421                            },
 422                        )?;
 423                    }
 424                }
 425            }
 426
 427            for project_id in removed_connection.guest_project_ids {
 428                if let Some(project) = store.project(project_id).trace_err() {
 429                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 430                        self.peer.send(
 431                            conn_id,
 432                            proto::RemoveProjectCollaborator {
 433                                project_id,
 434                                peer_id: connection_id.0,
 435                            },
 436                        )
 437                    });
 438                    if project.guests.is_empty() {
 439                        self.peer
 440                            .send(
 441                                project.host_connection_id,
 442                                proto::ProjectUnshared { project_id },
 443                            )
 444                            .trace_err();
 445                    }
 446                }
 447            }
 448
 449            removed_connection.user_id
 450        };
 451
 452        self.update_user_contacts(removed_user_id).await?;
 453
 454        Ok(())
 455    }
 456
 457    pub async fn invite_code_redeemed(
 458        self: &Arc<Self>,
 459        code: &str,
 460        invitee_id: UserId,
 461    ) -> Result<()> {
 462        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 463        let store = self.store().await;
 464        let invitee_contact = store.contact_for_user(invitee_id, true);
 465        for connection_id in store.connection_ids_for_user(user.id) {
 466            self.peer.send(
 467                connection_id,
 468                proto::UpdateContacts {
 469                    contacts: vec![invitee_contact.clone()],
 470                    ..Default::default()
 471                },
 472            )?;
 473            self.peer.send(
 474                connection_id,
 475                proto::UpdateInviteInfo {
 476                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 477                    count: user.invite_count as u32,
 478                },
 479            )?;
 480        }
 481        Ok(())
 482    }
 483
 484    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 485        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 486            if let Some(invite_code) = &user.invite_code {
 487                let store = self.store().await;
 488                for connection_id in store.connection_ids_for_user(user_id) {
 489                    self.peer.send(
 490                        connection_id,
 491                        proto::UpdateInviteInfo {
 492                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
 493                            count: user.invite_count as u32,
 494                        },
 495                    )?;
 496                }
 497            }
 498        }
 499        Ok(())
 500    }
 501
 502    async fn ping(
 503        self: Arc<Server>,
 504        _: TypedEnvelope<proto::Ping>,
 505        response: Response<proto::Ping>,
 506    ) -> Result<()> {
 507        response.send(proto::Ack {})?;
 508        Ok(())
 509    }
 510
 511    async fn register_project(
 512        self: Arc<Server>,
 513        request: TypedEnvelope<proto::RegisterProject>,
 514        response: Response<proto::RegisterProject>,
 515    ) -> Result<()> {
 516        let project_id;
 517        {
 518            let mut state = self.store_mut().await;
 519            let user_id = state.user_id_for_connection(request.sender_id)?;
 520            project_id = state.register_project(request.sender_id, user_id);
 521        };
 522
 523        response.send(proto::RegisterProjectResponse { project_id })?;
 524
 525        Ok(())
 526    }
 527
 528    async fn unregister_project(
 529        self: Arc<Server>,
 530        request: TypedEnvelope<proto::UnregisterProject>,
 531        response: Response<proto::UnregisterProject>,
 532    ) -> Result<()> {
 533        let (user_id, project) = {
 534            let mut state = self.store_mut().await;
 535            let project =
 536                state.unregister_project(request.payload.project_id, request.sender_id)?;
 537            (state.user_id_for_connection(request.sender_id)?, project)
 538        };
 539
 540        broadcast(
 541            request.sender_id,
 542            project.guests.keys().copied(),
 543            |conn_id| {
 544                self.peer.send(
 545                    conn_id,
 546                    proto::UnregisterProject {
 547                        project_id: request.payload.project_id,
 548                    },
 549                )
 550            },
 551        );
 552        for (_, receipts) in project.join_requests {
 553            for receipt in receipts {
 554                self.peer.respond(
 555                    receipt,
 556                    proto::JoinProjectResponse {
 557                        variant: Some(proto::join_project_response::Variant::Decline(
 558                            proto::join_project_response::Decline {
 559                                reason: proto::join_project_response::decline::Reason::Closed
 560                                    as i32,
 561                            },
 562                        )),
 563                    },
 564                )?;
 565            }
 566        }
 567
 568        // Send out the `UpdateContacts` message before responding to the unregister
 569        // request. This way, when the project's host can keep track of the project's
 570        // remote id until after they've received the `UpdateContacts` message for
 571        // themself.
 572        self.update_user_contacts(user_id).await?;
 573        response.send(proto::Ack {})?;
 574
 575        Ok(())
 576    }
 577
 578    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 579        let contacts = self.app_state.db.get_contacts(user_id).await?;
 580        let store = self.store().await;
 581        let updated_contact = store.contact_for_user(user_id, false);
 582        for contact in contacts {
 583            if let db::Contact::Accepted {
 584                user_id: contact_user_id,
 585                ..
 586            } = contact
 587            {
 588                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 589                    self.peer
 590                        .send(
 591                            contact_conn_id,
 592                            proto::UpdateContacts {
 593                                contacts: vec![updated_contact.clone()],
 594                                remove_contacts: Default::default(),
 595                                incoming_requests: Default::default(),
 596                                remove_incoming_requests: Default::default(),
 597                                outgoing_requests: Default::default(),
 598                                remove_outgoing_requests: Default::default(),
 599                            },
 600                        )
 601                        .trace_err();
 602                }
 603            }
 604        }
 605        Ok(())
 606    }
 607
 608    async fn join_project(
 609        self: Arc<Server>,
 610        request: TypedEnvelope<proto::JoinProject>,
 611        response: Response<proto::JoinProject>,
 612    ) -> Result<()> {
 613        let project_id = request.payload.project_id;
 614
 615        let host_user_id;
 616        let guest_user_id;
 617        let host_connection_id;
 618        {
 619            let state = self.store().await;
 620            let project = state.project(project_id)?;
 621            host_user_id = project.host_user_id;
 622            host_connection_id = project.host_connection_id;
 623            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 624        };
 625
 626        tracing::info!(project_id, %host_user_id, %host_connection_id, "join project");
 627        let has_contact = self
 628            .app_state
 629            .db
 630            .has_contact(guest_user_id, host_user_id)
 631            .await?;
 632        if !has_contact {
 633            return Err(anyhow!("no such project"))?;
 634        }
 635
 636        self.store_mut().await.request_join_project(
 637            guest_user_id,
 638            project_id,
 639            response.into_receipt(),
 640        )?;
 641        self.peer.send(
 642            host_connection_id,
 643            proto::RequestJoinProject {
 644                project_id,
 645                requester_id: guest_user_id.to_proto(),
 646            },
 647        )?;
 648        Ok(())
 649    }
 650
 651    async fn respond_to_join_project_request(
 652        self: Arc<Server>,
 653        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 654    ) -> Result<()> {
 655        let host_user_id;
 656
 657        {
 658            let mut state = self.store_mut().await;
 659            let project_id = request.payload.project_id;
 660            let project = state.project(project_id)?;
 661            if project.host_connection_id != request.sender_id {
 662                Err(anyhow!("no such connection"))?;
 663            }
 664
 665            host_user_id = project.host_user_id;
 666            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 667
 668            if !request.payload.allow {
 669                let receipts = state
 670                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 671                    .ok_or_else(|| anyhow!("no such request"))?;
 672                for receipt in receipts {
 673                    self.peer.respond(
 674                        receipt,
 675                        proto::JoinProjectResponse {
 676                            variant: Some(proto::join_project_response::Variant::Decline(
 677                                proto::join_project_response::Decline {
 678                                    reason: proto::join_project_response::decline::Reason::Declined
 679                                        as i32,
 680                                },
 681                            )),
 682                        },
 683                    )?;
 684                }
 685                return Ok(());
 686            }
 687
 688            let (receipts_with_replica_ids, project) = state
 689                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 690                .ok_or_else(|| anyhow!("no such request"))?;
 691
 692            let peer_count = project.guests.len();
 693            let mut collaborators = Vec::with_capacity(peer_count);
 694            collaborators.push(proto::Collaborator {
 695                peer_id: project.host_connection_id.0,
 696                replica_id: 0,
 697                user_id: project.host_user_id.to_proto(),
 698            });
 699            let worktrees = project
 700                .worktrees
 701                .iter()
 702                .filter_map(|(id, shared_worktree)| {
 703                    let worktree = project.worktrees.get(&id)?;
 704                    Some(proto::Worktree {
 705                        id: *id,
 706                        root_name: worktree.root_name.clone(),
 707                        entries: shared_worktree.entries.values().cloned().collect(),
 708                        diagnostic_summaries: shared_worktree
 709                            .diagnostic_summaries
 710                            .values()
 711                            .cloned()
 712                            .collect(),
 713                        visible: worktree.visible,
 714                        scan_id: shared_worktree.scan_id,
 715                    })
 716                })
 717                .collect::<Vec<_>>();
 718
 719            // Add all guests other than the requesting user's own connections as collaborators
 720            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 721                if receipts_with_replica_ids
 722                    .iter()
 723                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 724                {
 725                    collaborators.push(proto::Collaborator {
 726                        peer_id: peer_conn_id.0,
 727                        replica_id: *peer_replica_id as u32,
 728                        user_id: peer_user_id.to_proto(),
 729                    });
 730                }
 731            }
 732
 733            for conn_id in project.connection_ids() {
 734                for (receipt, replica_id) in &receipts_with_replica_ids {
 735                    if conn_id != receipt.sender_id {
 736                        self.peer.send(
 737                            conn_id,
 738                            proto::AddProjectCollaborator {
 739                                project_id,
 740                                collaborator: Some(proto::Collaborator {
 741                                    peer_id: receipt.sender_id.0,
 742                                    replica_id: *replica_id as u32,
 743                                    user_id: guest_user_id.to_proto(),
 744                                }),
 745                            },
 746                        )?;
 747                    }
 748                }
 749            }
 750
 751            for (receipt, replica_id) in receipts_with_replica_ids {
 752                self.peer.respond(
 753                    receipt,
 754                    proto::JoinProjectResponse {
 755                        variant: Some(proto::join_project_response::Variant::Accept(
 756                            proto::join_project_response::Accept {
 757                                worktrees: worktrees.clone(),
 758                                replica_id: replica_id as u32,
 759                                collaborators: collaborators.clone(),
 760                                language_servers: project.language_servers.clone(),
 761                            },
 762                        )),
 763                    },
 764                )?;
 765            }
 766        }
 767
 768        self.update_user_contacts(host_user_id).await?;
 769        Ok(())
 770    }
 771
 772    async fn leave_project(
 773        self: Arc<Server>,
 774        request: TypedEnvelope<proto::LeaveProject>,
 775    ) -> Result<()> {
 776        let sender_id = request.sender_id;
 777        let project_id = request.payload.project_id;
 778        let project;
 779        {
 780            let mut store = self.store_mut().await;
 781            project = store.leave_project(sender_id, project_id)?;
 782            tracing::info!(
 783                project_id,
 784                host_user_id = %project.host_user_id,
 785                host_connection_id = %project.host_connection_id,
 786                "leave project"
 787            );
 788
 789            if project.remove_collaborator {
 790                broadcast(sender_id, project.connection_ids, |conn_id| {
 791                    self.peer.send(
 792                        conn_id,
 793                        proto::RemoveProjectCollaborator {
 794                            project_id,
 795                            peer_id: sender_id.0,
 796                        },
 797                    )
 798                });
 799            }
 800
 801            if let Some(requester_id) = project.cancel_request {
 802                self.peer.send(
 803                    project.host_connection_id,
 804                    proto::JoinProjectRequestCancelled {
 805                        project_id,
 806                        requester_id: requester_id.to_proto(),
 807                    },
 808                )?;
 809            }
 810
 811            if project.unshare {
 812                self.peer.send(
 813                    project.host_connection_id,
 814                    proto::ProjectUnshared { project_id },
 815                )?;
 816            }
 817        }
 818        self.update_user_contacts(project.host_user_id).await?;
 819        Ok(())
 820    }
 821
 822    async fn update_project(
 823        self: Arc<Server>,
 824        request: TypedEnvelope<proto::UpdateProject>,
 825    ) -> Result<()> {
 826        let user_id;
 827        {
 828            let mut state = self.store_mut().await;
 829            user_id = state.user_id_for_connection(request.sender_id)?;
 830            let guest_connection_ids = state
 831                .read_project(request.payload.project_id, request.sender_id)?
 832                .guest_connection_ids();
 833            state.update_project(
 834                request.payload.project_id,
 835                &request.payload.worktrees,
 836                request.sender_id,
 837            )?;
 838            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 839                self.peer
 840                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 841            });
 842        };
 843        self.update_user_contacts(user_id).await?;
 844        Ok(())
 845    }
 846
 847    async fn update_worktree(
 848        self: Arc<Server>,
 849        request: TypedEnvelope<proto::UpdateWorktree>,
 850        response: Response<proto::UpdateWorktree>,
 851    ) -> Result<()> {
 852        let (connection_ids, metadata_changed) = {
 853            let mut store = self.store_mut().await;
 854            let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
 855                request.sender_id,
 856                request.payload.project_id,
 857                request.payload.worktree_id,
 858                &request.payload.root_name,
 859                &request.payload.removed_entries,
 860                &request.payload.updated_entries,
 861                request.payload.scan_id,
 862            )?;
 863            for (extension, count) in extension_counts {
 864                tracing::info!(
 865                    project_id = request.payload.project_id,
 866                    worktree_id = request.payload.worktree_id,
 867                    ?extension,
 868                    %count,
 869                    "worktree updated"
 870                );
 871            }
 872            (connection_ids, metadata_changed)
 873        };
 874
 875        broadcast(request.sender_id, connection_ids, |connection_id| {
 876            self.peer
 877                .forward_send(request.sender_id, connection_id, request.payload.clone())
 878        });
 879        if metadata_changed {
 880            let user_id = self
 881                .store()
 882                .await
 883                .user_id_for_connection(request.sender_id)?;
 884            self.update_user_contacts(user_id).await?;
 885        }
 886        response.send(proto::Ack {})?;
 887        Ok(())
 888    }
 889
 890    async fn update_diagnostic_summary(
 891        self: Arc<Server>,
 892        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 893    ) -> Result<()> {
 894        let summary = request
 895            .payload
 896            .summary
 897            .clone()
 898            .ok_or_else(|| anyhow!("invalid summary"))?;
 899        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 900            request.payload.project_id,
 901            request.payload.worktree_id,
 902            request.sender_id,
 903            summary,
 904        )?;
 905
 906        broadcast(request.sender_id, receiver_ids, |connection_id| {
 907            self.peer
 908                .forward_send(request.sender_id, connection_id, request.payload.clone())
 909        });
 910        Ok(())
 911    }
 912
 913    async fn start_language_server(
 914        self: Arc<Server>,
 915        request: TypedEnvelope<proto::StartLanguageServer>,
 916    ) -> Result<()> {
 917        let receiver_ids = self.store_mut().await.start_language_server(
 918            request.payload.project_id,
 919            request.sender_id,
 920            request
 921                .payload
 922                .server
 923                .clone()
 924                .ok_or_else(|| anyhow!("invalid language server"))?,
 925        )?;
 926        broadcast(request.sender_id, receiver_ids, |connection_id| {
 927            self.peer
 928                .forward_send(request.sender_id, connection_id, request.payload.clone())
 929        });
 930        Ok(())
 931    }
 932
 933    async fn update_language_server(
 934        self: Arc<Server>,
 935        request: TypedEnvelope<proto::UpdateLanguageServer>,
 936    ) -> Result<()> {
 937        let receiver_ids = self
 938            .store()
 939            .await
 940            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 941        broadcast(request.sender_id, receiver_ids, |connection_id| {
 942            self.peer
 943                .forward_send(request.sender_id, connection_id, request.payload.clone())
 944        });
 945        Ok(())
 946    }
 947
 948    async fn forward_project_request<T>(
 949        self: Arc<Server>,
 950        request: TypedEnvelope<T>,
 951        response: Response<T>,
 952    ) -> Result<()>
 953    where
 954        T: EntityMessage + RequestMessage,
 955    {
 956        let host_connection_id = self
 957            .store()
 958            .await
 959            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 960            .host_connection_id;
 961
 962        response.send(
 963            self.peer
 964                .forward_request(request.sender_id, host_connection_id, request.payload)
 965                .await?,
 966        )?;
 967        Ok(())
 968    }
 969
 970    async fn save_buffer(
 971        self: Arc<Server>,
 972        request: TypedEnvelope<proto::SaveBuffer>,
 973        response: Response<proto::SaveBuffer>,
 974    ) -> Result<()> {
 975        let host = self
 976            .store()
 977            .await
 978            .read_project(request.payload.project_id, request.sender_id)?
 979            .host_connection_id;
 980        let response_payload = self
 981            .peer
 982            .forward_request(request.sender_id, host, request.payload.clone())
 983            .await?;
 984
 985        let mut guests = self
 986            .store()
 987            .await
 988            .read_project(request.payload.project_id, request.sender_id)?
 989            .connection_ids();
 990        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
 991        broadcast(host, guests, |conn_id| {
 992            self.peer
 993                .forward_send(host, conn_id, response_payload.clone())
 994        });
 995        response.send(response_payload)?;
 996        Ok(())
 997    }
 998
 999    async fn update_buffer(
1000        self: Arc<Server>,
1001        request: TypedEnvelope<proto::UpdateBuffer>,
1002        response: Response<proto::UpdateBuffer>,
1003    ) -> Result<()> {
1004        let receiver_ids = self
1005            .store()
1006            .await
1007            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1008        broadcast(request.sender_id, receiver_ids, |connection_id| {
1009            self.peer
1010                .forward_send(request.sender_id, connection_id, request.payload.clone())
1011        });
1012        response.send(proto::Ack {})?;
1013        Ok(())
1014    }
1015
1016    async fn update_buffer_file(
1017        self: Arc<Server>,
1018        request: TypedEnvelope<proto::UpdateBufferFile>,
1019    ) -> Result<()> {
1020        let receiver_ids = self
1021            .store()
1022            .await
1023            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1024        broadcast(request.sender_id, receiver_ids, |connection_id| {
1025            self.peer
1026                .forward_send(request.sender_id, connection_id, request.payload.clone())
1027        });
1028        Ok(())
1029    }
1030
1031    async fn buffer_reloaded(
1032        self: Arc<Server>,
1033        request: TypedEnvelope<proto::BufferReloaded>,
1034    ) -> Result<()> {
1035        let receiver_ids = self
1036            .store()
1037            .await
1038            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1039        broadcast(request.sender_id, receiver_ids, |connection_id| {
1040            self.peer
1041                .forward_send(request.sender_id, connection_id, request.payload.clone())
1042        });
1043        Ok(())
1044    }
1045
1046    async fn buffer_saved(
1047        self: Arc<Server>,
1048        request: TypedEnvelope<proto::BufferSaved>,
1049    ) -> Result<()> {
1050        let receiver_ids = self
1051            .store()
1052            .await
1053            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1054        broadcast(request.sender_id, receiver_ids, |connection_id| {
1055            self.peer
1056                .forward_send(request.sender_id, connection_id, request.payload.clone())
1057        });
1058        Ok(())
1059    }
1060
1061    async fn follow(
1062        self: Arc<Self>,
1063        request: TypedEnvelope<proto::Follow>,
1064        response: Response<proto::Follow>,
1065    ) -> Result<()> {
1066        let leader_id = ConnectionId(request.payload.leader_id);
1067        let follower_id = request.sender_id;
1068        if !self
1069            .store()
1070            .await
1071            .project_connection_ids(request.payload.project_id, follower_id)?
1072            .contains(&leader_id)
1073        {
1074            Err(anyhow!("no such peer"))?;
1075        }
1076        let mut response_payload = self
1077            .peer
1078            .forward_request(request.sender_id, leader_id, request.payload)
1079            .await?;
1080        response_payload
1081            .views
1082            .retain(|view| view.leader_id != Some(follower_id.0));
1083        response.send(response_payload)?;
1084        Ok(())
1085    }
1086
1087    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1088        let leader_id = ConnectionId(request.payload.leader_id);
1089        if !self
1090            .store()
1091            .await
1092            .project_connection_ids(request.payload.project_id, request.sender_id)?
1093            .contains(&leader_id)
1094        {
1095            Err(anyhow!("no such peer"))?;
1096        }
1097        self.peer
1098            .forward_send(request.sender_id, leader_id, request.payload)?;
1099        Ok(())
1100    }
1101
1102    async fn update_followers(
1103        self: Arc<Self>,
1104        request: TypedEnvelope<proto::UpdateFollowers>,
1105    ) -> Result<()> {
1106        let connection_ids = self
1107            .store()
1108            .await
1109            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1110        let leader_id = request
1111            .payload
1112            .variant
1113            .as_ref()
1114            .and_then(|variant| match variant {
1115                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1116                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1117                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1118            });
1119        for follower_id in &request.payload.follower_ids {
1120            let follower_id = ConnectionId(*follower_id);
1121            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1122                self.peer
1123                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1124            }
1125        }
1126        Ok(())
1127    }
1128
1129    async fn get_channels(
1130        self: Arc<Server>,
1131        request: TypedEnvelope<proto::GetChannels>,
1132        response: Response<proto::GetChannels>,
1133    ) -> Result<()> {
1134        let user_id = self
1135            .store()
1136            .await
1137            .user_id_for_connection(request.sender_id)?;
1138        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1139        response.send(proto::GetChannelsResponse {
1140            channels: channels
1141                .into_iter()
1142                .map(|chan| proto::Channel {
1143                    id: chan.id.to_proto(),
1144                    name: chan.name,
1145                })
1146                .collect(),
1147        })?;
1148        Ok(())
1149    }
1150
1151    async fn get_users(
1152        self: Arc<Server>,
1153        request: TypedEnvelope<proto::GetUsers>,
1154        response: Response<proto::GetUsers>,
1155    ) -> Result<()> {
1156        let user_ids = request
1157            .payload
1158            .user_ids
1159            .into_iter()
1160            .map(UserId::from_proto)
1161            .collect();
1162        let users = self
1163            .app_state
1164            .db
1165            .get_users_by_ids(user_ids)
1166            .await?
1167            .into_iter()
1168            .map(|user| proto::User {
1169                id: user.id.to_proto(),
1170                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1171                github_login: user.github_login,
1172            })
1173            .collect();
1174        response.send(proto::UsersResponse { users })?;
1175        Ok(())
1176    }
1177
1178    async fn fuzzy_search_users(
1179        self: Arc<Server>,
1180        request: TypedEnvelope<proto::FuzzySearchUsers>,
1181        response: Response<proto::FuzzySearchUsers>,
1182    ) -> Result<()> {
1183        let user_id = self
1184            .store()
1185            .await
1186            .user_id_for_connection(request.sender_id)?;
1187        let query = request.payload.query;
1188        let db = &self.app_state.db;
1189        let users = match query.len() {
1190            0 => vec![],
1191            1 | 2 => db
1192                .get_user_by_github_login(&query)
1193                .await?
1194                .into_iter()
1195                .collect(),
1196            _ => db.fuzzy_search_users(&query, 10).await?,
1197        };
1198        let users = users
1199            .into_iter()
1200            .filter(|user| user.id != user_id)
1201            .map(|user| proto::User {
1202                id: user.id.to_proto(),
1203                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1204                github_login: user.github_login,
1205            })
1206            .collect();
1207        response.send(proto::UsersResponse { users })?;
1208        Ok(())
1209    }
1210
1211    async fn request_contact(
1212        self: Arc<Server>,
1213        request: TypedEnvelope<proto::RequestContact>,
1214        response: Response<proto::RequestContact>,
1215    ) -> Result<()> {
1216        let requester_id = self
1217            .store()
1218            .await
1219            .user_id_for_connection(request.sender_id)?;
1220        let responder_id = UserId::from_proto(request.payload.responder_id);
1221        if requester_id == responder_id {
1222            return Err(anyhow!("cannot add yourself as a contact"))?;
1223        }
1224
1225        self.app_state
1226            .db
1227            .send_contact_request(requester_id, responder_id)
1228            .await?;
1229
1230        // Update outgoing contact requests of requester
1231        let mut update = proto::UpdateContacts::default();
1232        update.outgoing_requests.push(responder_id.to_proto());
1233        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1234            self.peer.send(connection_id, update.clone())?;
1235        }
1236
1237        // Update incoming contact requests of responder
1238        let mut update = proto::UpdateContacts::default();
1239        update
1240            .incoming_requests
1241            .push(proto::IncomingContactRequest {
1242                requester_id: requester_id.to_proto(),
1243                should_notify: true,
1244            });
1245        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1246            self.peer.send(connection_id, update.clone())?;
1247        }
1248
1249        response.send(proto::Ack {})?;
1250        Ok(())
1251    }
1252
1253    async fn respond_to_contact_request(
1254        self: Arc<Server>,
1255        request: TypedEnvelope<proto::RespondToContactRequest>,
1256        response: Response<proto::RespondToContactRequest>,
1257    ) -> Result<()> {
1258        let responder_id = self
1259            .store()
1260            .await
1261            .user_id_for_connection(request.sender_id)?;
1262        let requester_id = UserId::from_proto(request.payload.requester_id);
1263        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1264            self.app_state
1265                .db
1266                .dismiss_contact_notification(responder_id, requester_id)
1267                .await?;
1268        } else {
1269            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1270            self.app_state
1271                .db
1272                .respond_to_contact_request(responder_id, requester_id, accept)
1273                .await?;
1274
1275            let store = self.store().await;
1276            // Update responder with new contact
1277            let mut update = proto::UpdateContacts::default();
1278            if accept {
1279                update
1280                    .contacts
1281                    .push(store.contact_for_user(requester_id, false));
1282            }
1283            update
1284                .remove_incoming_requests
1285                .push(requester_id.to_proto());
1286            for connection_id in store.connection_ids_for_user(responder_id) {
1287                self.peer.send(connection_id, update.clone())?;
1288            }
1289
1290            // Update requester with new contact
1291            let mut update = proto::UpdateContacts::default();
1292            if accept {
1293                update
1294                    .contacts
1295                    .push(store.contact_for_user(responder_id, true));
1296            }
1297            update
1298                .remove_outgoing_requests
1299                .push(responder_id.to_proto());
1300            for connection_id in store.connection_ids_for_user(requester_id) {
1301                self.peer.send(connection_id, update.clone())?;
1302            }
1303        }
1304
1305        response.send(proto::Ack {})?;
1306        Ok(())
1307    }
1308
1309    async fn remove_contact(
1310        self: Arc<Server>,
1311        request: TypedEnvelope<proto::RemoveContact>,
1312        response: Response<proto::RemoveContact>,
1313    ) -> Result<()> {
1314        let requester_id = self
1315            .store()
1316            .await
1317            .user_id_for_connection(request.sender_id)?;
1318        let responder_id = UserId::from_proto(request.payload.user_id);
1319        self.app_state
1320            .db
1321            .remove_contact(requester_id, responder_id)
1322            .await?;
1323
1324        // Update outgoing contact requests of requester
1325        let mut update = proto::UpdateContacts::default();
1326        update
1327            .remove_outgoing_requests
1328            .push(responder_id.to_proto());
1329        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1330            self.peer.send(connection_id, update.clone())?;
1331        }
1332
1333        // Update incoming contact requests of responder
1334        let mut update = proto::UpdateContacts::default();
1335        update
1336            .remove_incoming_requests
1337            .push(requester_id.to_proto());
1338        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1339            self.peer.send(connection_id, update.clone())?;
1340        }
1341
1342        response.send(proto::Ack {})?;
1343        Ok(())
1344    }
1345
1346    async fn join_channel(
1347        self: Arc<Self>,
1348        request: TypedEnvelope<proto::JoinChannel>,
1349        response: Response<proto::JoinChannel>,
1350    ) -> Result<()> {
1351        let user_id = self
1352            .store()
1353            .await
1354            .user_id_for_connection(request.sender_id)?;
1355        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1356        if !self
1357            .app_state
1358            .db
1359            .can_user_access_channel(user_id, channel_id)
1360            .await?
1361        {
1362            Err(anyhow!("access denied"))?;
1363        }
1364
1365        self.store_mut()
1366            .await
1367            .join_channel(request.sender_id, channel_id);
1368        let messages = self
1369            .app_state
1370            .db
1371            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1372            .await?
1373            .into_iter()
1374            .map(|msg| proto::ChannelMessage {
1375                id: msg.id.to_proto(),
1376                body: msg.body,
1377                timestamp: msg.sent_at.unix_timestamp() as u64,
1378                sender_id: msg.sender_id.to_proto(),
1379                nonce: Some(msg.nonce.as_u128().into()),
1380            })
1381            .collect::<Vec<_>>();
1382        response.send(proto::JoinChannelResponse {
1383            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1384            messages,
1385        })?;
1386        Ok(())
1387    }
1388
1389    async fn leave_channel(
1390        self: Arc<Self>,
1391        request: TypedEnvelope<proto::LeaveChannel>,
1392    ) -> Result<()> {
1393        let user_id = self
1394            .store()
1395            .await
1396            .user_id_for_connection(request.sender_id)?;
1397        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1398        if !self
1399            .app_state
1400            .db
1401            .can_user_access_channel(user_id, channel_id)
1402            .await?
1403        {
1404            Err(anyhow!("access denied"))?;
1405        }
1406
1407        self.store_mut()
1408            .await
1409            .leave_channel(request.sender_id, channel_id);
1410
1411        Ok(())
1412    }
1413
1414    async fn send_channel_message(
1415        self: Arc<Self>,
1416        request: TypedEnvelope<proto::SendChannelMessage>,
1417        response: Response<proto::SendChannelMessage>,
1418    ) -> Result<()> {
1419        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1420        let user_id;
1421        let connection_ids;
1422        {
1423            let state = self.store().await;
1424            user_id = state.user_id_for_connection(request.sender_id)?;
1425            connection_ids = state.channel_connection_ids(channel_id)?;
1426        }
1427
1428        // Validate the message body.
1429        let body = request.payload.body.trim().to_string();
1430        if body.len() > MAX_MESSAGE_LEN {
1431            return Err(anyhow!("message is too long"))?;
1432        }
1433        if body.is_empty() {
1434            return Err(anyhow!("message can't be blank"))?;
1435        }
1436
1437        let timestamp = OffsetDateTime::now_utc();
1438        let nonce = request
1439            .payload
1440            .nonce
1441            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1442
1443        let message_id = self
1444            .app_state
1445            .db
1446            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1447            .await?
1448            .to_proto();
1449        let message = proto::ChannelMessage {
1450            sender_id: user_id.to_proto(),
1451            id: message_id,
1452            body,
1453            timestamp: timestamp.unix_timestamp() as u64,
1454            nonce: Some(nonce),
1455        };
1456        broadcast(request.sender_id, connection_ids, |conn_id| {
1457            self.peer.send(
1458                conn_id,
1459                proto::ChannelMessageSent {
1460                    channel_id: channel_id.to_proto(),
1461                    message: Some(message.clone()),
1462                },
1463            )
1464        });
1465        response.send(proto::SendChannelMessageResponse {
1466            message: Some(message),
1467        })?;
1468        Ok(())
1469    }
1470
1471    async fn get_channel_messages(
1472        self: Arc<Self>,
1473        request: TypedEnvelope<proto::GetChannelMessages>,
1474        response: Response<proto::GetChannelMessages>,
1475    ) -> Result<()> {
1476        let user_id = self
1477            .store()
1478            .await
1479            .user_id_for_connection(request.sender_id)?;
1480        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1481        if !self
1482            .app_state
1483            .db
1484            .can_user_access_channel(user_id, channel_id)
1485            .await?
1486        {
1487            Err(anyhow!("access denied"))?;
1488        }
1489
1490        let messages = self
1491            .app_state
1492            .db
1493            .get_channel_messages(
1494                channel_id,
1495                MESSAGE_COUNT_PER_PAGE,
1496                Some(MessageId::from_proto(request.payload.before_message_id)),
1497            )
1498            .await?
1499            .into_iter()
1500            .map(|msg| proto::ChannelMessage {
1501                id: msg.id.to_proto(),
1502                body: msg.body,
1503                timestamp: msg.sent_at.unix_timestamp() as u64,
1504                sender_id: msg.sender_id.to_proto(),
1505                nonce: Some(msg.nonce.as_u128().into()),
1506            })
1507            .collect::<Vec<_>>();
1508        response.send(proto::GetChannelMessagesResponse {
1509            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1510            messages,
1511        })?;
1512        Ok(())
1513    }
1514
    /// Acquires a shared read lock on the server's `Store`.
    ///
    /// In `cfg(test)` builds the task yields once before and once after taking
    /// the lock. This is presumably meant to exercise more task interleavings
    /// in tests (TODO confirm).
    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.read().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        // NOTE(review): `_not_send` presumably makes the guard `!Send`, so it
        // cannot be held across an `.await` in a `Send` future. Confirm against
        // the `StoreReadGuard` definition.
        StoreReadGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1526
    /// Acquires an exclusive write lock on the server's `Store`.
    ///
    /// In `cfg(test)` builds the task yields once before and once after taking
    /// the lock, mirroring `store()`. Dropping the returned guard runs
    /// `StoreWriteGuard`'s `Drop` impl.
    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.write().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        // NOTE(review): `_not_send` presumably makes the guard `!Send`. Confirm
        // against the `StoreWriteGuard` definition.
        StoreWriteGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1538
1539    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1540        ServerSnapshot {
1541            store: self.store.read().await,
1542            peer: &self.peer,
1543        }
1544    }
1545}
1546
1547impl<'a> Deref for StoreReadGuard<'a> {
1548    type Target = Store;
1549
1550    fn deref(&self) -> &Self::Target {
1551        &*self.guard
1552    }
1553}
1554
1555impl<'a> Deref for StoreWriteGuard<'a> {
1556    type Target = Store;
1557
1558    fn deref(&self) -> &Self::Target {
1559        &*self.guard
1560    }
1561}
1562
1563impl<'a> DerefMut for StoreWriteGuard<'a> {
1564    fn deref_mut(&mut self) -> &mut Self::Target {
1565        &mut *self.guard
1566    }
1567}
1568
1569impl<'a> Drop for StoreWriteGuard<'a> {
1570    fn drop(&mut self) {
1571        #[cfg(test)]
1572        self.check_invariants();
1573
1574        let metrics = self.metrics();
1575
1576        METRIC_CONNECTIONS.set(metrics.connections as _);
1577        METRIC_PROJECTS.set(metrics.registered_projects as _);
1578        METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1579
1580        tracing::info!(
1581            connections = metrics.connections,
1582            registered_projects = metrics.registered_projects,
1583            shared_projects = metrics.shared_projects,
1584            "metrics"
1585        );
1586    }
1587}
1588
1589impl Executor for RealExecutor {
1590    type Sleep = Sleep;
1591
1592    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1593        tokio::task::spawn(future);
1594    }
1595
1596    fn sleep(&self, duration: Duration) -> Self::Sleep {
1597        tokio::time::sleep(duration)
1598    }
1599}
1600
1601fn broadcast<F>(
1602    sender_id: ConnectionId,
1603    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1604    mut f: F,
1605) where
1606    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1607{
1608    for receiver_id in receiver_ids {
1609        if receiver_id != sender_id {
1610            f(receiver_id).trace_err();
1611        }
1612    }
1613}
1614
1615lazy_static! {
1616    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1617}
1618
/// Typed value of the `x-zed-protocol-version` request header: the RPC
/// protocol version reported by the connecting client.
///
/// Derives the standard value traits so the public newtype can be logged,
/// copied, and compared like the `u32` it wraps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProtocolVersion(u32);
1620
1621impl Header for ProtocolVersion {
1622    fn name() -> &'static HeaderName {
1623        &ZED_PROTOCOL_VERSION
1624    }
1625
1626    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1627    where
1628        Self: Sized,
1629        I: Iterator<Item = &'i axum::http::HeaderValue>,
1630    {
1631        let version = values
1632            .next()
1633            .ok_or_else(|| axum::headers::Error::invalid())?
1634            .to_str()
1635            .map_err(|_| axum::headers::Error::invalid())?
1636            .parse()
1637            .map_err(|_| axum::headers::Error::invalid())?;
1638        Ok(Self(version))
1639    }
1640
1641    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1642        values.extend([self.0.to_string().parse().unwrap()]);
1643    }
1644}
1645
1646pub fn routes(server: Arc<Server>) -> Router<Body> {
1647    Router::new()
1648        .route("/rpc", get(handle_websocket_request))
1649        .layer(
1650            ServiceBuilder::new()
1651                .layer(Extension(server.app_state.clone()))
1652                .layer(middleware::from_fn(auth::validate_header))
1653                .layer(Extension(server)),
1654        )
1655        .route("/metrics", get(handle_metrics))
1656}
1657
/// Axum handler for `GET /rpc`.
///
/// It checks the client's protocol version, then upgrades the request to a
/// WebSocket served by `Server::handle_connection`.
///
/// A request whose `x-zed-protocol-version` header is missing or not a valid
/// `u32` is rejected by the `TypedHeader<ProtocolVersion>` extractor before this
/// body runs. A version different from `rpc::PROTOCOL_VERSION` is answered with
/// `426 Upgrade Required`.
///
/// NOTE(review): the `User` extension is presumably inserted by
/// `auth::validate_header`, which `routes` layers over `/rpc` — confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject incompatible clients before upgrading, so they receive a plain
    // HTTP status instead of a WebSocket they cannot speak on.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    // `handle_connection` receives the peer address in string form.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Imported locally: this module defines its own `ResultExt` (with
        // `trace_err`); this one supplies the `log_err` used below.
        use util::ResultExt;
        // Adapt the axum WebSocket to tungstenite messages for `Connection`.
        // Incoming messages are converted as they are read, outgoing ones as
        // they are written, and stream errors are converted via `err_into`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // The upgrade callback's future yields `()`, so the connection's
            // result is handed to `log_err` rather than propagated.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1688
1689pub async fn handle_metrics() -> axum::response::Response {
1690    let encoder = prometheus::TextEncoder::new();
1691    let metric_families = prometheus::gather();
1692    match encoder.encode_to_string(&metric_families) {
1693        Ok(string) => (StatusCode::OK, string).into_response(),
1694        Err(error) => (
1695            StatusCode::INTERNAL_SERVER_ERROR,
1696            format!("failed to encode metrics {:?}", error),
1697        )
1698            .into_response(),
1699    }
1700}
1701
1702fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1703    match message {
1704        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1705        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1706        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1707        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1708        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1709            code: frame.code.into(),
1710            reason: frame.reason,
1711        })),
1712    }
1713}
1714
1715fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1716    match message {
1717        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1718        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1719        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1720        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1721        AxumMessage::Close(frame) => {
1722            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1723                code: frame.code.into(),
1724                reason: frame.reason,
1725            }))
1726        }
1727    }
1728}
1729
/// Extension trait for turning a fallible result into an `Option` while
/// reporting the failure through `tracing`.
///
/// This module implements it for `Result<T, E>` where `E: Debug`.
pub trait ResultExt {
    /// The success type of the implementing result.
    type Ok;

    /// Returns the success value as `Some`. On failure, emits a `tracing` error
    /// event with the error's `Debug` form and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1735
1736impl<T, E> ResultExt for Result<T, E>
1737where
1738    E: std::fmt::Debug,
1739{
1740    type Ok = T;
1741
1742    fn trace_err(self) -> Option<T> {
1743        match self {
1744            Ok(value) => Some(value),
1745            Err(error) => {
1746                tracing::error!("{:?}", error);
1747                None
1748            }
1749        }
1750    }
1751}