rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, MessageId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::HashMap;
  26use futures::{
  27    channel::mpsc,
  28    future::{self, BoxFuture},
  29    FutureExt, SinkExt, StreamExt, TryStreamExt,
  30};
  31use lazy_static::lazy_static;
  32use rpc::{
  33    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  34    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  35};
  36use serde::{Serialize, Serializer};
  37use std::{
  38    any::TypeId,
  39    future::Future,
  40    marker::PhantomData,
  41    net::SocketAddr,
  42    ops::{Deref, DerefMut},
  43    rc::Rc,
  44    sync::{
  45        atomic::{AtomicBool, Ordering::SeqCst},
  46        Arc,
  47    },
  48    time::Duration,
  49};
  50use time::OffsetDateTime;
  51use tokio::{
  52    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
  53    time::Sleep,
  54};
  55use tower::ServiceBuilder;
  56use tracing::{info_span, instrument, Instrument};
  57
  58pub use store::{Store, Worktree};
  59
  60type MessageHandler =
  61    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
  62
  63struct Response<R> {
  64    server: Arc<Server>,
  65    receipt: Receipt<R>,
  66    responded: Arc<AtomicBool>,
  67}
  68
  69impl<R: RequestMessage> Response<R> {
  70    fn send(self, payload: R::Response) -> Result<()> {
  71        self.responded.store(true, SeqCst);
  72        self.server.peer.respond(self.receipt, payload)?;
  73        Ok(())
  74    }
  75
  76    fn into_receipt(self) -> Receipt<R> {
  77        self.responded.store(true, SeqCst);
  78        self.receipt
  79    }
  80}
  81
  82pub struct Server {
  83    peer: Arc<Peer>,
  84    pub(crate) store: RwLock<Store>,
  85    app_state: Arc<AppState>,
  86    handlers: HashMap<TypeId, MessageHandler>,
  87    notifications: Option<mpsc::UnboundedSender<()>>,
  88}
  89
  90pub trait Executor: Send + Clone {
  91    type Sleep: Send + Future;
  92    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
  93    fn sleep(&self, duration: Duration) -> Self::Sleep;
  94}
  95
  96#[derive(Clone)]
  97pub struct RealExecutor;
  98
  99const MESSAGE_COUNT_PER_PAGE: usize = 100;
 100const MAX_MESSAGE_LEN: usize = 1024;
 101
 102struct StoreReadGuard<'a> {
 103    guard: RwLockReadGuard<'a, Store>,
 104    _not_send: PhantomData<Rc<()>>,
 105}
 106
 107struct StoreWriteGuard<'a> {
 108    guard: RwLockWriteGuard<'a, Store>,
 109    _not_send: PhantomData<Rc<()>>,
 110}
 111
 112#[derive(Serialize)]
 113pub struct ServerSnapshot<'a> {
 114    peer: &'a Peer,
 115    #[serde(serialize_with = "serialize_deref")]
 116    store: RwLockReadGuard<'a, Store>,
 117}
 118
 119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 120where
 121    S: Serializer,
 122    T: Deref<Target = U>,
 123    U: Serialize,
 124{
 125    Serialize::serialize(value.deref(), serializer)
 126}
 127
 128impl Server {
 129    pub fn new(
 130        app_state: Arc<AppState>,
 131        notifications: Option<mpsc::UnboundedSender<()>>,
 132    ) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138            notifications,
 139        };
 140
 141        server
 142            .add_request_handler(Server::ping)
 143            .add_request_handler(Server::register_project)
 144            .add_message_handler(Server::unregister_project)
 145            .add_request_handler(Server::join_project)
 146            .add_message_handler(Server::leave_project)
 147            .add_message_handler(Server::respond_to_join_project_request)
 148            .add_request_handler(Server::register_worktree)
 149            .add_message_handler(Server::unregister_worktree)
 150            .add_request_handler(Server::update_worktree)
 151            .add_message_handler(Server::start_language_server)
 152            .add_message_handler(Server::update_language_server)
 153            .add_message_handler(Server::update_diagnostic_summary)
 154            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 155            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 156            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 157            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 158            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 159            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 160            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 161            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 162            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 163            .add_request_handler(
 164                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 165            )
 166            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 167            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 168            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 169            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 170            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 171            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 172            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 173            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 174            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 175            .add_request_handler(Server::update_buffer)
 176            .add_message_handler(Server::update_buffer_file)
 177            .add_message_handler(Server::buffer_reloaded)
 178            .add_message_handler(Server::buffer_saved)
 179            .add_request_handler(Server::save_buffer)
 180            .add_request_handler(Server::get_channels)
 181            .add_request_handler(Server::get_users)
 182            .add_request_handler(Server::fuzzy_search_users)
 183            .add_request_handler(Server::request_contact)
 184            .add_request_handler(Server::remove_contact)
 185            .add_request_handler(Server::respond_to_contact_request)
 186            .add_request_handler(Server::join_channel)
 187            .add_message_handler(Server::leave_channel)
 188            .add_request_handler(Server::send_channel_message)
 189            .add_request_handler(Server::follow)
 190            .add_message_handler(Server::unfollow)
 191            .add_message_handler(Server::update_followers)
 192            .add_request_handler(Server::get_channel_messages);
 193
 194        Arc::new(server)
 195    }
 196
 197    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 198    where
 199        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 200        Fut: 'static + Send + Future<Output = Result<()>>,
 201        M: EnvelopedMessage,
 202    {
 203        let prev_handler = self.handlers.insert(
 204            TypeId::of::<M>(),
 205            Box::new(move |server, envelope| {
 206                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 207                let span = info_span!(
 208                    "handle message",
 209                    payload_type = envelope.payload_type_name(),
 210                    payload = format!("{:?}", envelope.payload).as_str(),
 211                );
 212                let future = (handler)(server, *envelope);
 213                async move {
 214                    if let Err(error) = future.await {
 215                        tracing::error!(%error, "error handling message");
 216                    }
 217                }
 218                .instrument(span)
 219                .boxed()
 220            }),
 221        );
 222        if prev_handler.is_some() {
 223            panic!("registered a handler for the same message twice");
 224        }
 225        self
 226    }
 227
 228    /// Handle a request while holding a lock to the store. This is useful when we're registering
 229    /// a connection but we want to respond on the connection before anybody else can send on it.
 230    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 231    where
 232        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
 233        Fut: Send + Future<Output = Result<()>>,
 234        M: RequestMessage,
 235    {
 236        let handler = Arc::new(handler);
 237        self.add_message_handler(move |server, envelope| {
 238            let receipt = envelope.receipt();
 239            let handler = handler.clone();
 240            async move {
 241                let responded = Arc::new(AtomicBool::default());
 242                let response = Response {
 243                    server: server.clone(),
 244                    responded: responded.clone(),
 245                    receipt: envelope.receipt(),
 246                };
 247                match (handler)(server.clone(), envelope, response).await {
 248                    Ok(()) => {
 249                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 250                            Ok(())
 251                        } else {
 252                            Err(anyhow!("handler did not send a response"))?
 253                        }
 254                    }
 255                    Err(error) => {
 256                        server.peer.respond_with_error(
 257                            receipt,
 258                            proto::Error {
 259                                message: error.to_string(),
 260                            },
 261                        )?;
 262                        Err(error)
 263                    }
 264                }
 265            }
 266        })
 267    }
 268
 269    pub fn handle_connection<E: Executor>(
 270        self: &Arc<Self>,
 271        connection: Connection,
 272        address: String,
 273        user: User,
 274        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
 275        executor: E,
 276    ) -> impl Future<Output = Result<()>> {
 277        let mut this = self.clone();
 278        let user_id = user.id;
 279        let login = user.github_login;
 280        let span = info_span!("handle connection", %user_id, %login, %address);
 281        async move {
 282            let (connection_id, handle_io, mut incoming_rx) = this
 283                .peer
 284                .add_connection(connection, {
 285                    let executor = executor.clone();
 286                    move |duration| {
 287                        let timer = executor.sleep(duration);
 288                        async move {
 289                            timer.await;
 290                        }
 291                    }
 292                })
 293                .await;
 294
 295            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 296
 297            if let Some(send_connection_id) = send_connection_id.as_mut() {
 298                let _ = send_connection_id.send(connection_id).await;
 299            }
 300
 301            if !user.connected_once {
 302                this.peer.send(connection_id, proto::ShowContacts {})?;
 303                this.app_state.db.set_user_connected_once(user_id, true).await?;
 304            }
 305
 306            let (contacts, invite_code) = future::try_join(
 307                this.app_state.db.get_contacts(user_id),
 308                this.app_state.db.get_invite_code_for_user(user_id)
 309            ).await?;
 310
 311            {
 312                let mut store = this.store_mut().await;
 313                store.add_connection(connection_id, user_id);
 314                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 315
 316                if let Some((code, count)) = invite_code {
 317                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 318                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
 319                        count,
 320                    })?;
 321                }
 322            }
 323            this.update_user_contacts(user_id).await?;
 324
 325            let handle_io = handle_io.fuse();
 326            futures::pin_mut!(handle_io);
 327            loop {
 328                let next_message = incoming_rx.next().fuse();
 329                futures::pin_mut!(next_message);
 330                futures::select_biased! {
 331                    result = handle_io => {
 332                        if let Err(error) = result {
 333                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 334                        }
 335                        break;
 336                    }
 337                    message = next_message => {
 338                        if let Some(message) = message {
 339                            let type_name = message.payload_type_name();
 340                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 341                            async {
 342                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 343                                    let notifications = this.notifications.clone();
 344                                    let is_background = message.is_background();
 345                                    let handle_message = (handler)(this.clone(), message);
 346                                    let handle_message = async move {
 347                                        handle_message.await;
 348                                        if let Some(mut notifications) = notifications {
 349                                            let _ = notifications.send(()).await;
 350                                        }
 351                                    };
 352                                    if is_background {
 353                                        executor.spawn_detached(handle_message);
 354                                    } else {
 355                                        handle_message.await;
 356                                    }
 357                                } else {
 358                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 359                                }
 360                            }.instrument(span).await;
 361                        } else {
 362                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 363                            break;
 364                        }
 365                    }
 366                }
 367            }
 368
 369            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 370            if let Err(error) = this.sign_out(connection_id).await {
 371                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 372            }
 373
 374            Ok(())
 375        }.instrument(span)
 376    }
 377
 378    #[instrument(skip(self), err)]
 379    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
 380        self.peer.disconnect(connection_id);
 381
 382        let removed_user_id = {
 383            let mut store = self.store_mut().await;
 384            let removed_connection = store.remove_connection(connection_id)?;
 385
 386            for (project_id, project) in removed_connection.hosted_projects {
 387                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
 388                    self.peer
 389                        .send(conn_id, proto::UnregisterProject { project_id })
 390                });
 391
 392                for (_, receipts) in project.join_requests {
 393                    for receipt in receipts {
 394                        self.peer.respond(
 395                            receipt,
 396                            proto::JoinProjectResponse {
 397                                variant: Some(proto::join_project_response::Variant::Decline(
 398                                    proto::join_project_response::Decline {
 399                                        reason: proto::join_project_response::decline::Reason::WentOffline as i32
 400                                    },
 401                                )),
 402                            },
 403                        )?;
 404                    }
 405                }
 406            }
 407
 408            for project_id in removed_connection.guest_project_ids {
 409                if let Some(project) = store.project(project_id).trace_err() {
 410                    broadcast(connection_id, project.connection_ids(), |conn_id| {
 411                        self.peer.send(
 412                            conn_id,
 413                            proto::RemoveProjectCollaborator {
 414                                project_id,
 415                                peer_id: connection_id.0,
 416                            },
 417                        )
 418                    });
 419                    if project.guests.is_empty() {
 420                        self.peer
 421                            .send(
 422                                project.host_connection_id,
 423                                proto::ProjectUnshared { project_id },
 424                            )
 425                            .trace_err();
 426                    }
 427                }
 428            }
 429
 430            removed_connection.user_id
 431        };
 432
 433        self.update_user_contacts(removed_user_id).await?;
 434
 435        Ok(())
 436    }
 437
 438    pub async fn invite_code_redeemed(
 439        self: &Arc<Self>,
 440        code: &str,
 441        invitee_id: UserId,
 442    ) -> Result<()> {
 443        let user = self.app_state.db.get_user_for_invite_code(code).await?;
 444        let store = self.store().await;
 445        let invitee_contact = store.contact_for_user(invitee_id, true);
 446        for connection_id in store.connection_ids_for_user(user.id) {
 447            self.peer.send(
 448                connection_id,
 449                proto::UpdateContacts {
 450                    contacts: vec![invitee_contact.clone()],
 451                    ..Default::default()
 452                },
 453            )?;
 454            self.peer.send(
 455                connection_id,
 456                proto::UpdateInviteInfo {
 457                    url: format!("{}{}", self.app_state.invite_link_prefix, code),
 458                    count: user.invite_count as u32,
 459                },
 460            )?;
 461        }
 462        Ok(())
 463    }
 464
 465    async fn ping(
 466        self: Arc<Server>,
 467        _: TypedEnvelope<proto::Ping>,
 468        response: Response<proto::Ping>,
 469    ) -> Result<()> {
 470        response.send(proto::Ack {})?;
 471        Ok(())
 472    }
 473
 474    async fn register_project(
 475        self: Arc<Server>,
 476        request: TypedEnvelope<proto::RegisterProject>,
 477        response: Response<proto::RegisterProject>,
 478    ) -> Result<()> {
 479        let user_id;
 480        let project_id;
 481        {
 482            let mut state = self.store_mut().await;
 483            user_id = state.user_id_for_connection(request.sender_id)?;
 484            project_id = state.register_project(request.sender_id, user_id);
 485        };
 486        self.update_user_contacts(user_id).await?;
 487        response.send(proto::RegisterProjectResponse { project_id })?;
 488        Ok(())
 489    }
 490
 491    async fn unregister_project(
 492        self: Arc<Server>,
 493        request: TypedEnvelope<proto::UnregisterProject>,
 494    ) -> Result<()> {
 495        let (user_id, project) = {
 496            let mut state = self.store_mut().await;
 497            let project =
 498                state.unregister_project(request.payload.project_id, request.sender_id)?;
 499            (state.user_id_for_connection(request.sender_id)?, project)
 500        };
 501
 502        broadcast(
 503            request.sender_id,
 504            project.guests.keys().copied(),
 505            |conn_id| {
 506                self.peer.send(
 507                    conn_id,
 508                    proto::UnregisterProject {
 509                        project_id: request.payload.project_id,
 510                    },
 511                )
 512            },
 513        );
 514        for (_, receipts) in project.join_requests {
 515            for receipt in receipts {
 516                self.peer.respond(
 517                    receipt,
 518                    proto::JoinProjectResponse {
 519                        variant: Some(proto::join_project_response::Variant::Decline(
 520                            proto::join_project_response::Decline {
 521                                reason: proto::join_project_response::decline::Reason::Closed
 522                                    as i32,
 523                            },
 524                        )),
 525                    },
 526                )?;
 527            }
 528        }
 529
 530        self.update_user_contacts(user_id).await?;
 531        Ok(())
 532    }
 533
 534    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 535        let contacts = self.app_state.db.get_contacts(user_id).await?;
 536        let store = self.store().await;
 537        let updated_contact = store.contact_for_user(user_id, false);
 538        for contact in contacts {
 539            if let db::Contact::Accepted {
 540                user_id: contact_user_id,
 541                ..
 542            } = contact
 543            {
 544                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 545                    self.peer
 546                        .send(
 547                            contact_conn_id,
 548                            proto::UpdateContacts {
 549                                contacts: vec![updated_contact.clone()],
 550                                remove_contacts: Default::default(),
 551                                incoming_requests: Default::default(),
 552                                remove_incoming_requests: Default::default(),
 553                                outgoing_requests: Default::default(),
 554                                remove_outgoing_requests: Default::default(),
 555                            },
 556                        )
 557                        .trace_err();
 558                }
 559            }
 560        }
 561        Ok(())
 562    }
 563
 564    async fn join_project(
 565        self: Arc<Server>,
 566        request: TypedEnvelope<proto::JoinProject>,
 567        response: Response<proto::JoinProject>,
 568    ) -> Result<()> {
 569        let project_id = request.payload.project_id;
 570        let host_user_id;
 571        let guest_user_id;
 572        let host_connection_id;
 573        {
 574            let state = self.store().await;
 575            let project = state.project(project_id)?;
 576            host_user_id = project.host_user_id;
 577            host_connection_id = project.host_connection_id;
 578            guest_user_id = state.user_id_for_connection(request.sender_id)?;
 579        };
 580
 581        let has_contact = self
 582            .app_state
 583            .db
 584            .has_contact(guest_user_id, host_user_id)
 585            .await?;
 586        if !has_contact {
 587            return Err(anyhow!("no such project"))?;
 588        }
 589
 590        self.store_mut().await.request_join_project(
 591            guest_user_id,
 592            project_id,
 593            response.into_receipt(),
 594        )?;
 595        self.peer.send(
 596            host_connection_id,
 597            proto::RequestJoinProject {
 598                project_id,
 599                requester_id: guest_user_id.to_proto(),
 600            },
 601        )?;
 602        Ok(())
 603    }
 604
 605    async fn respond_to_join_project_request(
 606        self: Arc<Server>,
 607        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
 608    ) -> Result<()> {
 609        let host_user_id;
 610
 611        {
 612            let mut state = self.store_mut().await;
 613            let project_id = request.payload.project_id;
 614            let project = state.project(project_id)?;
 615            if project.host_connection_id != request.sender_id {
 616                Err(anyhow!("no such connection"))?;
 617            }
 618
 619            host_user_id = project.host_user_id;
 620            let guest_user_id = UserId::from_proto(request.payload.requester_id);
 621
 622            if !request.payload.allow {
 623                let receipts = state
 624                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
 625                    .ok_or_else(|| anyhow!("no such request"))?;
 626                for receipt in receipts {
 627                    self.peer.respond(
 628                        receipt,
 629                        proto::JoinProjectResponse {
 630                            variant: Some(proto::join_project_response::Variant::Decline(
 631                                proto::join_project_response::Decline {
 632                                    reason: proto::join_project_response::decline::Reason::Declined
 633                                        as i32,
 634                                },
 635                            )),
 636                        },
 637                    )?;
 638                }
 639                return Ok(());
 640            }
 641
 642            let (receipts_with_replica_ids, project) = state
 643                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
 644                .ok_or_else(|| anyhow!("no such request"))?;
 645
 646            let peer_count = project.guests.len();
 647            let mut collaborators = Vec::with_capacity(peer_count);
 648            collaborators.push(proto::Collaborator {
 649                peer_id: project.host_connection_id.0,
 650                replica_id: 0,
 651                user_id: project.host_user_id.to_proto(),
 652            });
 653            let worktrees = project
 654                .worktrees
 655                .iter()
 656                .filter_map(|(id, shared_worktree)| {
 657                    let worktree = project.worktrees.get(&id)?;
 658                    Some(proto::Worktree {
 659                        id: *id,
 660                        root_name: worktree.root_name.clone(),
 661                        entries: shared_worktree.entries.values().cloned().collect(),
 662                        diagnostic_summaries: shared_worktree
 663                            .diagnostic_summaries
 664                            .values()
 665                            .cloned()
 666                            .collect(),
 667                        visible: worktree.visible,
 668                        scan_id: shared_worktree.scan_id,
 669                    })
 670                })
 671                .collect::<Vec<_>>();
 672
 673            // Add all guests other than the requesting user's own connections as collaborators
 674            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
 675                if receipts_with_replica_ids
 676                    .iter()
 677                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
 678                {
 679                    collaborators.push(proto::Collaborator {
 680                        peer_id: peer_conn_id.0,
 681                        replica_id: *peer_replica_id as u32,
 682                        user_id: peer_user_id.to_proto(),
 683                    });
 684                }
 685            }
 686
 687            for conn_id in project.connection_ids() {
 688                for (receipt, replica_id) in &receipts_with_replica_ids {
 689                    if conn_id != receipt.sender_id {
 690                        self.peer.send(
 691                            conn_id,
 692                            proto::AddProjectCollaborator {
 693                                project_id,
 694                                collaborator: Some(proto::Collaborator {
 695                                    peer_id: receipt.sender_id.0,
 696                                    replica_id: *replica_id as u32,
 697                                    user_id: guest_user_id.to_proto(),
 698                                }),
 699                            },
 700                        )?;
 701                    }
 702                }
 703            }
 704
 705            for (receipt, replica_id) in receipts_with_replica_ids {
 706                self.peer.respond(
 707                    receipt,
 708                    proto::JoinProjectResponse {
 709                        variant: Some(proto::join_project_response::Variant::Accept(
 710                            proto::join_project_response::Accept {
 711                                worktrees: worktrees.clone(),
 712                                replica_id: replica_id as u32,
 713                                collaborators: collaborators.clone(),
 714                                language_servers: project.language_servers.clone(),
 715                            },
 716                        )),
 717                    },
 718                )?;
 719            }
 720        }
 721
 722        self.update_user_contacts(host_user_id).await?;
 723        Ok(())
 724    }
 725
 726    async fn leave_project(
 727        self: Arc<Server>,
 728        request: TypedEnvelope<proto::LeaveProject>,
 729    ) -> Result<()> {
 730        let sender_id = request.sender_id;
 731        let project_id = request.payload.project_id;
 732        let project;
 733        {
 734            let mut store = self.store_mut().await;
 735            project = store.leave_project(sender_id, project_id)?;
 736
 737            if project.remove_collaborator {
 738                broadcast(sender_id, project.connection_ids, |conn_id| {
 739                    self.peer.send(
 740                        conn_id,
 741                        proto::RemoveProjectCollaborator {
 742                            project_id,
 743                            peer_id: sender_id.0,
 744                        },
 745                    )
 746                });
 747            }
 748
 749            if let Some(requester_id) = project.cancel_request {
 750                self.peer.send(
 751                    project.host_connection_id,
 752                    proto::JoinProjectRequestCancelled {
 753                        project_id,
 754                        requester_id: requester_id.to_proto(),
 755                    },
 756                )?;
 757            }
 758
 759            if project.unshare {
 760                self.peer.send(
 761                    project.host_connection_id,
 762                    proto::ProjectUnshared { project_id },
 763                )?;
 764            }
 765        }
 766        self.update_user_contacts(project.host_user_id).await?;
 767        Ok(())
 768    }
 769
 770    async fn register_worktree(
 771        self: Arc<Server>,
 772        request: TypedEnvelope<proto::RegisterWorktree>,
 773        response: Response<proto::RegisterWorktree>,
 774    ) -> Result<()> {
 775        let host_user_id;
 776        {
 777            let mut state = self.store_mut().await;
 778            host_user_id = state.user_id_for_connection(request.sender_id)?;
 779
 780            let guest_connection_ids = state
 781                .read_project(request.payload.project_id, request.sender_id)?
 782                .guest_connection_ids();
 783            state.register_worktree(
 784                request.payload.project_id,
 785                request.payload.worktree_id,
 786                request.sender_id,
 787                Worktree {
 788                    root_name: request.payload.root_name.clone(),
 789                    visible: request.payload.visible,
 790                    ..Default::default()
 791                },
 792            )?;
 793
 794            broadcast(request.sender_id, guest_connection_ids, |connection_id| {
 795                self.peer
 796                    .forward_send(request.sender_id, connection_id, request.payload.clone())
 797            });
 798        }
 799        self.update_user_contacts(host_user_id).await?;
 800        response.send(proto::Ack {})?;
 801        Ok(())
 802    }
 803
 804    async fn unregister_worktree(
 805        self: Arc<Server>,
 806        request: TypedEnvelope<proto::UnregisterWorktree>,
 807    ) -> Result<()> {
 808        let host_user_id;
 809        let project_id = request.payload.project_id;
 810        let worktree_id = request.payload.worktree_id;
 811        {
 812            let mut state = self.store_mut().await;
 813            let (_, guest_connection_ids) =
 814                state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
 815            host_user_id = state.user_id_for_connection(request.sender_id)?;
 816            broadcast(request.sender_id, guest_connection_ids, |conn_id| {
 817                self.peer.send(
 818                    conn_id,
 819                    proto::UnregisterWorktree {
 820                        project_id,
 821                        worktree_id,
 822                    },
 823                )
 824            });
 825        }
 826        self.update_user_contacts(host_user_id).await?;
 827        Ok(())
 828    }
 829
 830    async fn update_worktree(
 831        self: Arc<Server>,
 832        request: TypedEnvelope<proto::UpdateWorktree>,
 833        response: Response<proto::UpdateWorktree>,
 834    ) -> Result<()> {
 835        let connection_ids = self.store_mut().await.update_worktree(
 836            request.sender_id,
 837            request.payload.project_id,
 838            request.payload.worktree_id,
 839            &request.payload.removed_entries,
 840            &request.payload.updated_entries,
 841            request.payload.scan_id,
 842        )?;
 843
 844        broadcast(request.sender_id, connection_ids, |connection_id| {
 845            self.peer
 846                .forward_send(request.sender_id, connection_id, request.payload.clone())
 847        });
 848        response.send(proto::Ack {})?;
 849        Ok(())
 850    }
 851
 852    async fn update_diagnostic_summary(
 853        self: Arc<Server>,
 854        request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
 855    ) -> Result<()> {
 856        let summary = request
 857            .payload
 858            .summary
 859            .clone()
 860            .ok_or_else(|| anyhow!("invalid summary"))?;
 861        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
 862            request.payload.project_id,
 863            request.payload.worktree_id,
 864            request.sender_id,
 865            summary,
 866        )?;
 867
 868        broadcast(request.sender_id, receiver_ids, |connection_id| {
 869            self.peer
 870                .forward_send(request.sender_id, connection_id, request.payload.clone())
 871        });
 872        Ok(())
 873    }
 874
 875    async fn start_language_server(
 876        self: Arc<Server>,
 877        request: TypedEnvelope<proto::StartLanguageServer>,
 878    ) -> Result<()> {
 879        let receiver_ids = self.store_mut().await.start_language_server(
 880            request.payload.project_id,
 881            request.sender_id,
 882            request
 883                .payload
 884                .server
 885                .clone()
 886                .ok_or_else(|| anyhow!("invalid language server"))?,
 887        )?;
 888        broadcast(request.sender_id, receiver_ids, |connection_id| {
 889            self.peer
 890                .forward_send(request.sender_id, connection_id, request.payload.clone())
 891        });
 892        Ok(())
 893    }
 894
 895    async fn update_language_server(
 896        self: Arc<Server>,
 897        request: TypedEnvelope<proto::UpdateLanguageServer>,
 898    ) -> Result<()> {
 899        let receiver_ids = self
 900            .store()
 901            .await
 902            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 903        broadcast(request.sender_id, receiver_ids, |connection_id| {
 904            self.peer
 905                .forward_send(request.sender_id, connection_id, request.payload.clone())
 906        });
 907        Ok(())
 908    }
 909
 910    async fn forward_project_request<T>(
 911        self: Arc<Server>,
 912        request: TypedEnvelope<T>,
 913        response: Response<T>,
 914    ) -> Result<()>
 915    where
 916        T: EntityMessage + RequestMessage,
 917    {
 918        let host_connection_id = self
 919            .store()
 920            .await
 921            .read_project(request.payload.remote_entity_id(), request.sender_id)?
 922            .host_connection_id;
 923
 924        response.send(
 925            self.peer
 926                .forward_request(request.sender_id, host_connection_id, request.payload)
 927                .await?,
 928        )?;
 929        Ok(())
 930    }
 931
 932    async fn save_buffer(
 933        self: Arc<Server>,
 934        request: TypedEnvelope<proto::SaveBuffer>,
 935        response: Response<proto::SaveBuffer>,
 936    ) -> Result<()> {
 937        let host = self
 938            .store()
 939            .await
 940            .read_project(request.payload.project_id, request.sender_id)?
 941            .host_connection_id;
 942        let response_payload = self
 943            .peer
 944            .forward_request(request.sender_id, host, request.payload.clone())
 945            .await?;
 946
 947        let mut guests = self
 948            .store()
 949            .await
 950            .read_project(request.payload.project_id, request.sender_id)?
 951            .connection_ids();
 952        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
 953        broadcast(host, guests, |conn_id| {
 954            self.peer
 955                .forward_send(host, conn_id, response_payload.clone())
 956        });
 957        response.send(response_payload)?;
 958        Ok(())
 959    }
 960
 961    async fn update_buffer(
 962        self: Arc<Server>,
 963        request: TypedEnvelope<proto::UpdateBuffer>,
 964        response: Response<proto::UpdateBuffer>,
 965    ) -> Result<()> {
 966        let receiver_ids = self
 967            .store()
 968            .await
 969            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 970        broadcast(request.sender_id, receiver_ids, |connection_id| {
 971            self.peer
 972                .forward_send(request.sender_id, connection_id, request.payload.clone())
 973        });
 974        response.send(proto::Ack {})?;
 975        Ok(())
 976    }
 977
 978    async fn update_buffer_file(
 979        self: Arc<Server>,
 980        request: TypedEnvelope<proto::UpdateBufferFile>,
 981    ) -> Result<()> {
 982        let receiver_ids = self
 983            .store()
 984            .await
 985            .project_connection_ids(request.payload.project_id, request.sender_id)?;
 986        broadcast(request.sender_id, receiver_ids, |connection_id| {
 987            self.peer
 988                .forward_send(request.sender_id, connection_id, request.payload.clone())
 989        });
 990        Ok(())
 991    }
 992
 993    async fn buffer_reloaded(
 994        self: Arc<Server>,
 995        request: TypedEnvelope<proto::BufferReloaded>,
 996    ) -> Result<()> {
 997        let receiver_ids = self
 998            .store()
 999            .await
1000            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1001        broadcast(request.sender_id, receiver_ids, |connection_id| {
1002            self.peer
1003                .forward_send(request.sender_id, connection_id, request.payload.clone())
1004        });
1005        Ok(())
1006    }
1007
1008    async fn buffer_saved(
1009        self: Arc<Server>,
1010        request: TypedEnvelope<proto::BufferSaved>,
1011    ) -> Result<()> {
1012        let receiver_ids = self
1013            .store()
1014            .await
1015            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1016        broadcast(request.sender_id, receiver_ids, |connection_id| {
1017            self.peer
1018                .forward_send(request.sender_id, connection_id, request.payload.clone())
1019        });
1020        Ok(())
1021    }
1022
1023    async fn follow(
1024        self: Arc<Self>,
1025        request: TypedEnvelope<proto::Follow>,
1026        response: Response<proto::Follow>,
1027    ) -> Result<()> {
1028        let leader_id = ConnectionId(request.payload.leader_id);
1029        let follower_id = request.sender_id;
1030        if !self
1031            .store()
1032            .await
1033            .project_connection_ids(request.payload.project_id, follower_id)?
1034            .contains(&leader_id)
1035        {
1036            Err(anyhow!("no such peer"))?;
1037        }
1038        let mut response_payload = self
1039            .peer
1040            .forward_request(request.sender_id, leader_id, request.payload)
1041            .await?;
1042        response_payload
1043            .views
1044            .retain(|view| view.leader_id != Some(follower_id.0));
1045        response.send(response_payload)?;
1046        Ok(())
1047    }
1048
1049    async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1050        let leader_id = ConnectionId(request.payload.leader_id);
1051        if !self
1052            .store()
1053            .await
1054            .project_connection_ids(request.payload.project_id, request.sender_id)?
1055            .contains(&leader_id)
1056        {
1057            Err(anyhow!("no such peer"))?;
1058        }
1059        self.peer
1060            .forward_send(request.sender_id, leader_id, request.payload)?;
1061        Ok(())
1062    }
1063
1064    async fn update_followers(
1065        self: Arc<Self>,
1066        request: TypedEnvelope<proto::UpdateFollowers>,
1067    ) -> Result<()> {
1068        let connection_ids = self
1069            .store()
1070            .await
1071            .project_connection_ids(request.payload.project_id, request.sender_id)?;
1072        let leader_id = request
1073            .payload
1074            .variant
1075            .as_ref()
1076            .and_then(|variant| match variant {
1077                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1078                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1079                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1080            });
1081        for follower_id in &request.payload.follower_ids {
1082            let follower_id = ConnectionId(*follower_id);
1083            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1084                self.peer
1085                    .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1086            }
1087        }
1088        Ok(())
1089    }
1090
1091    async fn get_channels(
1092        self: Arc<Server>,
1093        request: TypedEnvelope<proto::GetChannels>,
1094        response: Response<proto::GetChannels>,
1095    ) -> Result<()> {
1096        let user_id = self
1097            .store()
1098            .await
1099            .user_id_for_connection(request.sender_id)?;
1100        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1101        response.send(proto::GetChannelsResponse {
1102            channels: channels
1103                .into_iter()
1104                .map(|chan| proto::Channel {
1105                    id: chan.id.to_proto(),
1106                    name: chan.name,
1107                })
1108                .collect(),
1109        })?;
1110        Ok(())
1111    }
1112
1113    async fn get_users(
1114        self: Arc<Server>,
1115        request: TypedEnvelope<proto::GetUsers>,
1116        response: Response<proto::GetUsers>,
1117    ) -> Result<()> {
1118        let user_ids = request
1119            .payload
1120            .user_ids
1121            .into_iter()
1122            .map(UserId::from_proto)
1123            .collect();
1124        let users = self
1125            .app_state
1126            .db
1127            .get_users_by_ids(user_ids)
1128            .await?
1129            .into_iter()
1130            .map(|user| proto::User {
1131                id: user.id.to_proto(),
1132                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1133                github_login: user.github_login,
1134            })
1135            .collect();
1136        response.send(proto::UsersResponse { users })?;
1137        Ok(())
1138    }
1139
1140    async fn fuzzy_search_users(
1141        self: Arc<Server>,
1142        request: TypedEnvelope<proto::FuzzySearchUsers>,
1143        response: Response<proto::FuzzySearchUsers>,
1144    ) -> Result<()> {
1145        let user_id = self
1146            .store()
1147            .await
1148            .user_id_for_connection(request.sender_id)?;
1149        let query = request.payload.query;
1150        let db = &self.app_state.db;
1151        let users = match query.len() {
1152            0 => vec![],
1153            1 | 2 => db
1154                .get_user_by_github_login(&query)
1155                .await?
1156                .into_iter()
1157                .collect(),
1158            _ => db.fuzzy_search_users(&query, 10).await?,
1159        };
1160        let users = users
1161            .into_iter()
1162            .filter(|user| user.id != user_id)
1163            .map(|user| proto::User {
1164                id: user.id.to_proto(),
1165                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1166                github_login: user.github_login,
1167            })
1168            .collect();
1169        response.send(proto::UsersResponse { users })?;
1170        Ok(())
1171    }
1172
1173    async fn request_contact(
1174        self: Arc<Server>,
1175        request: TypedEnvelope<proto::RequestContact>,
1176        response: Response<proto::RequestContact>,
1177    ) -> Result<()> {
1178        let requester_id = self
1179            .store()
1180            .await
1181            .user_id_for_connection(request.sender_id)?;
1182        let responder_id = UserId::from_proto(request.payload.responder_id);
1183        if requester_id == responder_id {
1184            return Err(anyhow!("cannot add yourself as a contact"))?;
1185        }
1186
1187        self.app_state
1188            .db
1189            .send_contact_request(requester_id, responder_id)
1190            .await?;
1191
1192        // Update outgoing contact requests of requester
1193        let mut update = proto::UpdateContacts::default();
1194        update.outgoing_requests.push(responder_id.to_proto());
1195        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1196            self.peer.send(connection_id, update.clone())?;
1197        }
1198
1199        // Update incoming contact requests of responder
1200        let mut update = proto::UpdateContacts::default();
1201        update
1202            .incoming_requests
1203            .push(proto::IncomingContactRequest {
1204                requester_id: requester_id.to_proto(),
1205                should_notify: true,
1206            });
1207        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1208            self.peer.send(connection_id, update.clone())?;
1209        }
1210
1211        response.send(proto::Ack {})?;
1212        Ok(())
1213    }
1214
1215    async fn respond_to_contact_request(
1216        self: Arc<Server>,
1217        request: TypedEnvelope<proto::RespondToContactRequest>,
1218        response: Response<proto::RespondToContactRequest>,
1219    ) -> Result<()> {
1220        let responder_id = self
1221            .store()
1222            .await
1223            .user_id_for_connection(request.sender_id)?;
1224        let requester_id = UserId::from_proto(request.payload.requester_id);
1225        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1226            self.app_state
1227                .db
1228                .dismiss_contact_notification(responder_id, requester_id)
1229                .await?;
1230        } else {
1231            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1232            self.app_state
1233                .db
1234                .respond_to_contact_request(responder_id, requester_id, accept)
1235                .await?;
1236
1237            let store = self.store().await;
1238            // Update responder with new contact
1239            let mut update = proto::UpdateContacts::default();
1240            if accept {
1241                update
1242                    .contacts
1243                    .push(store.contact_for_user(requester_id, false));
1244            }
1245            update
1246                .remove_incoming_requests
1247                .push(requester_id.to_proto());
1248            for connection_id in store.connection_ids_for_user(responder_id) {
1249                self.peer.send(connection_id, update.clone())?;
1250            }
1251
1252            // Update requester with new contact
1253            let mut update = proto::UpdateContacts::default();
1254            if accept {
1255                update
1256                    .contacts
1257                    .push(store.contact_for_user(responder_id, true));
1258            }
1259            update
1260                .remove_outgoing_requests
1261                .push(responder_id.to_proto());
1262            for connection_id in store.connection_ids_for_user(requester_id) {
1263                self.peer.send(connection_id, update.clone())?;
1264            }
1265        }
1266
1267        response.send(proto::Ack {})?;
1268        Ok(())
1269    }
1270
1271    async fn remove_contact(
1272        self: Arc<Server>,
1273        request: TypedEnvelope<proto::RemoveContact>,
1274        response: Response<proto::RemoveContact>,
1275    ) -> Result<()> {
1276        let requester_id = self
1277            .store()
1278            .await
1279            .user_id_for_connection(request.sender_id)?;
1280        let responder_id = UserId::from_proto(request.payload.user_id);
1281        self.app_state
1282            .db
1283            .remove_contact(requester_id, responder_id)
1284            .await?;
1285
1286        // Update outgoing contact requests of requester
1287        let mut update = proto::UpdateContacts::default();
1288        update
1289            .remove_outgoing_requests
1290            .push(responder_id.to_proto());
1291        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1292            self.peer.send(connection_id, update.clone())?;
1293        }
1294
1295        // Update incoming contact requests of responder
1296        let mut update = proto::UpdateContacts::default();
1297        update
1298            .remove_incoming_requests
1299            .push(requester_id.to_proto());
1300        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1301            self.peer.send(connection_id, update.clone())?;
1302        }
1303
1304        response.send(proto::Ack {})?;
1305        Ok(())
1306    }
1307
1308    async fn join_channel(
1309        self: Arc<Self>,
1310        request: TypedEnvelope<proto::JoinChannel>,
1311        response: Response<proto::JoinChannel>,
1312    ) -> Result<()> {
1313        let user_id = self
1314            .store()
1315            .await
1316            .user_id_for_connection(request.sender_id)?;
1317        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1318        if !self
1319            .app_state
1320            .db
1321            .can_user_access_channel(user_id, channel_id)
1322            .await?
1323        {
1324            Err(anyhow!("access denied"))?;
1325        }
1326
1327        self.store_mut()
1328            .await
1329            .join_channel(request.sender_id, channel_id);
1330        let messages = self
1331            .app_state
1332            .db
1333            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1334            .await?
1335            .into_iter()
1336            .map(|msg| proto::ChannelMessage {
1337                id: msg.id.to_proto(),
1338                body: msg.body,
1339                timestamp: msg.sent_at.unix_timestamp() as u64,
1340                sender_id: msg.sender_id.to_proto(),
1341                nonce: Some(msg.nonce.as_u128().into()),
1342            })
1343            .collect::<Vec<_>>();
1344        response.send(proto::JoinChannelResponse {
1345            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1346            messages,
1347        })?;
1348        Ok(())
1349    }
1350
1351    async fn leave_channel(
1352        self: Arc<Self>,
1353        request: TypedEnvelope<proto::LeaveChannel>,
1354    ) -> Result<()> {
1355        let user_id = self
1356            .store()
1357            .await
1358            .user_id_for_connection(request.sender_id)?;
1359        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1360        if !self
1361            .app_state
1362            .db
1363            .can_user_access_channel(user_id, channel_id)
1364            .await?
1365        {
1366            Err(anyhow!("access denied"))?;
1367        }
1368
1369        self.store_mut()
1370            .await
1371            .leave_channel(request.sender_id, channel_id);
1372
1373        Ok(())
1374    }
1375
1376    async fn send_channel_message(
1377        self: Arc<Self>,
1378        request: TypedEnvelope<proto::SendChannelMessage>,
1379        response: Response<proto::SendChannelMessage>,
1380    ) -> Result<()> {
1381        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1382        let user_id;
1383        let connection_ids;
1384        {
1385            let state = self.store().await;
1386            user_id = state.user_id_for_connection(request.sender_id)?;
1387            connection_ids = state.channel_connection_ids(channel_id)?;
1388        }
1389
1390        // Validate the message body.
1391        let body = request.payload.body.trim().to_string();
1392        if body.len() > MAX_MESSAGE_LEN {
1393            return Err(anyhow!("message is too long"))?;
1394        }
1395        if body.is_empty() {
1396            return Err(anyhow!("message can't be blank"))?;
1397        }
1398
1399        let timestamp = OffsetDateTime::now_utc();
1400        let nonce = request
1401            .payload
1402            .nonce
1403            .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1404
1405        let message_id = self
1406            .app_state
1407            .db
1408            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1409            .await?
1410            .to_proto();
1411        let message = proto::ChannelMessage {
1412            sender_id: user_id.to_proto(),
1413            id: message_id,
1414            body,
1415            timestamp: timestamp.unix_timestamp() as u64,
1416            nonce: Some(nonce),
1417        };
1418        broadcast(request.sender_id, connection_ids, |conn_id| {
1419            self.peer.send(
1420                conn_id,
1421                proto::ChannelMessageSent {
1422                    channel_id: channel_id.to_proto(),
1423                    message: Some(message.clone()),
1424                },
1425            )
1426        });
1427        response.send(proto::SendChannelMessageResponse {
1428            message: Some(message),
1429        })?;
1430        Ok(())
1431    }
1432
1433    async fn get_channel_messages(
1434        self: Arc<Self>,
1435        request: TypedEnvelope<proto::GetChannelMessages>,
1436        response: Response<proto::GetChannelMessages>,
1437    ) -> Result<()> {
1438        let user_id = self
1439            .store()
1440            .await
1441            .user_id_for_connection(request.sender_id)?;
1442        let channel_id = ChannelId::from_proto(request.payload.channel_id);
1443        if !self
1444            .app_state
1445            .db
1446            .can_user_access_channel(user_id, channel_id)
1447            .await?
1448        {
1449            Err(anyhow!("access denied"))?;
1450        }
1451
1452        let messages = self
1453            .app_state
1454            .db
1455            .get_channel_messages(
1456                channel_id,
1457                MESSAGE_COUNT_PER_PAGE,
1458                Some(MessageId::from_proto(request.payload.before_message_id)),
1459            )
1460            .await?
1461            .into_iter()
1462            .map(|msg| proto::ChannelMessage {
1463                id: msg.id.to_proto(),
1464                body: msg.body,
1465                timestamp: msg.sent_at.unix_timestamp() as u64,
1466                sender_id: msg.sender_id.to_proto(),
1467                nonce: Some(msg.nonce.as_u128().into()),
1468            })
1469            .collect::<Vec<_>>();
1470        response.send(proto::GetChannelMessagesResponse {
1471            done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1472            messages,
1473        })?;
1474        Ok(())
1475    }
1476
1477    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1478        #[cfg(test)]
1479        tokio::task::yield_now().await;
1480        let guard = self.store.read().await;
1481        #[cfg(test)]
1482        tokio::task::yield_now().await;
1483        StoreReadGuard {
1484            guard,
1485            _not_send: PhantomData,
1486        }
1487    }
1488
1489    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1490        #[cfg(test)]
1491        tokio::task::yield_now().await;
1492        let guard = self.store.write().await;
1493        #[cfg(test)]
1494        tokio::task::yield_now().await;
1495        StoreWriteGuard {
1496            guard,
1497            _not_send: PhantomData,
1498        }
1499    }
1500
1501    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1502        ServerSnapshot {
1503            store: self.store.read().await,
1504            peer: &self.peer,
1505        }
1506    }
1507}
1508
1509impl<'a> Deref for StoreReadGuard<'a> {
1510    type Target = Store;
1511
1512    fn deref(&self) -> &Self::Target {
1513        &*self.guard
1514    }
1515}
1516
1517impl<'a> Deref for StoreWriteGuard<'a> {
1518    type Target = Store;
1519
1520    fn deref(&self) -> &Self::Target {
1521        &*self.guard
1522    }
1523}
1524
1525impl<'a> DerefMut for StoreWriteGuard<'a> {
1526    fn deref_mut(&mut self) -> &mut Self::Target {
1527        &mut *self.guard
1528    }
1529}
1530
1531impl<'a> Drop for StoreWriteGuard<'a> {
1532    fn drop(&mut self) {
1533        #[cfg(test)]
1534        self.check_invariants();
1535
1536        let metrics = self.metrics();
1537        tracing::info!(
1538            connections = metrics.connections,
1539            registered_projects = metrics.registered_projects,
1540            shared_projects = metrics.shared_projects,
1541            "metrics"
1542        );
1543    }
1544}
1545
1546impl Executor for RealExecutor {
1547    type Sleep = Sleep;
1548
1549    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1550        tokio::task::spawn(future);
1551    }
1552
1553    fn sleep(&self, duration: Duration) -> Self::Sleep {
1554        tokio::time::sleep(duration)
1555    }
1556}
1557
1558fn broadcast<F>(
1559    sender_id: ConnectionId,
1560    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1561    mut f: F,
1562) where
1563    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1564{
1565    for receiver_id in receiver_ids {
1566        if receiver_id != sender_id {
1567            f(receiver_id).trace_err();
1568        }
1569    }
1570}
1571
1572lazy_static! {
1573    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1574}
1575
1576pub struct ProtocolVersion(u32);
1577
1578impl Header for ProtocolVersion {
1579    fn name() -> &'static HeaderName {
1580        &ZED_PROTOCOL_VERSION
1581    }
1582
1583    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1584    where
1585        Self: Sized,
1586        I: Iterator<Item = &'i axum::http::HeaderValue>,
1587    {
1588        let version = values
1589            .next()
1590            .ok_or_else(|| axum::headers::Error::invalid())?
1591            .to_str()
1592            .map_err(|_| axum::headers::Error::invalid())?
1593            .parse()
1594            .map_err(|_| axum::headers::Error::invalid())?;
1595        Ok(Self(version))
1596    }
1597
1598    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1599        values.extend([self.0.to_string().parse().unwrap()]);
1600    }
1601}
1602
1603pub fn routes(server: Arc<Server>) -> Router<Body> {
1604    Router::new()
1605        .route("/rpc", get(handle_websocket_request))
1606        .layer(
1607            ServiceBuilder::new()
1608                .layer(Extension(server.app_state.clone()))
1609                .layer(middleware::from_fn(auth::validate_header))
1610                .layer(Extension(server)),
1611        )
1612}
1613
1614pub async fn handle_websocket_request(
1615    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1616    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1617    Extension(server): Extension<Arc<Server>>,
1618    Extension(user): Extension<User>,
1619    ws: WebSocketUpgrade,
1620) -> axum::response::Response {
1621    if protocol_version != rpc::PROTOCOL_VERSION {
1622        return (
1623            StatusCode::UPGRADE_REQUIRED,
1624            "client must be upgraded".to_string(),
1625        )
1626            .into_response();
1627    }
1628    let socket_address = socket_address.to_string();
1629    ws.on_upgrade(move |socket| {
1630        use util::ResultExt;
1631        let socket = socket
1632            .map_ok(to_tungstenite_message)
1633            .err_into()
1634            .with(|message| async move { Ok(to_axum_message(message)) });
1635        let connection = Connection::new(Box::pin(socket));
1636        async move {
1637            server
1638                .handle_connection(connection, socket_address, user, None, RealExecutor)
1639                .await
1640                .log_err();
1641        }
1642    })
1643}
1644
1645fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1646    match message {
1647        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1648        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1649        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1650        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1651        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1652            code: frame.code.into(),
1653            reason: frame.reason,
1654        })),
1655    }
1656}
1657
1658fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1659    match message {
1660        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1661        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1662        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1663        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1664        AxumMessage::Close(frame) => {
1665            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1666                code: frame.code.into(),
1667                reason: frame.reason,
1668            }))
1669        }
1670    }
1671}
1672
/// Extension for fallible values whose failures should be reported and then
/// discarded rather than propagated.
///
/// In this module it is implemented for `Result<T, E>`; that impl logs the error
/// with `tracing::error!`.
pub trait ResultExt {
    /// Consumes `self`, yielding the success value if there is one. On failure,
    /// `None` is returned, after whatever reporting the implementation performs.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The type produced on success.
    type Ok;
}
1678
1679impl<T, E> ResultExt for Result<T, E>
1680where
1681    E: std::fmt::Debug,
1682{
1683    type Ok = T;
1684
1685    fn trace_err(self) -> Option<T> {
1686        match self {
1687            Ok(value) => Some(value),
1688            Err(error) => {
1689                tracing::error!("{:?}", error);
1690                None
1691            }
1692        }
1693    }
1694}