rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52pub use store::{Store, Worktree};
  53use tokio::{
  54    sync::{Mutex, MutexGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
// Prometheus gauges, created and registered lazily on first access.
// The `unwrap` calls panic if registration with the default registry fails.
lazy_static! {
    // Gauge registered under the name "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge registered under the name "shared_projects".
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  69
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives the server, the id of the sending user and the still-untyped
/// envelope, and returns a boxed future that resolves to `()`.
type MessageHandler = Box<
    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
>;
  73
/// An incoming payload together with the identity of its sender, as handed to
/// message and request handlers.
struct Message<T> {
    /// User the sending connection is associated with.
    sender_user_id: UserId,
    /// Connection the envelope arrived on (taken from the envelope's `sender_id`).
    sender_connection_id: ConnectionId,
    /// The decoded message body.
    payload: T,
}
  79
/// Handle given to a request handler for answering the request.
struct Response<R> {
    /// Server whose peer is used to deliver the response.
    server: Arc<Server>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Flag shared with the request dispatcher; set once `send` has been called.
    responded: Arc<AtomicBool>,
}
  85
  86impl<R: RequestMessage> Response<R> {
  87    fn send(self, payload: R::Response) -> Result<()> {
  88        self.responded.store(true, SeqCst);
  89        self.server.peer.respond(self.receipt, payload)?;
  90        Ok(())
  91    }
  92}
  93
/// The RPC server: owns the peer, the shared store and the table of message handlers.
pub struct Server {
    /// Peer used to send messages, requests and responses to connections.
    peer: Arc<Peer>,
    /// Shared store, guarded by an async mutex.
    pub(crate) store: Mutex<Store>,
    /// Application-wide state (database, configuration, optional LiveKit client).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
 100
/// Abstraction over the runtime used by `handle_connection` to spawn
/// background work and to create timers.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Runs `future` to completion without returning a handle to it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future for a sleep of `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 106
/// Unit executor type.
// NOTE(review): its `Executor` impl is not visible in this part of the file;
// presumably it is the production (tokio-backed) executor — confirm.
#[derive(Clone)]
pub struct RealExecutor;
 109
/// Wrapper around a locked [`Store`].
pub(crate) struct StoreGuard<'a> {
    /// The underlying mutex guard.
    guard: MutexGuard<'a, Store>,
    /// `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    // NOTE(review): presumably this is to stop the guard from being held across
    // `.await` points in futures that must be `Send` — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
 114
/// Serializable view of the server: the peer plus the locked store.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// Borrowed peer, serialized through its own `Serialize` impl.
    peer: &'a Peer,
    /// Locked store; serialized through its `Deref` target via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 121
 122pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 123where
 124    S: Serializer,
 125    T: Deref<Target = U>,
 126    U: Serialize,
 127{
 128    Serialize::serialize(value.deref(), serializer)
 129}
 130
 131impl Server {
 132    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138        };
 139
 140        server
 141            .add_request_handler(Server::ping)
 142            .add_request_handler(Server::create_room)
 143            .add_request_handler(Server::join_room)
 144            .add_message_handler(Server::leave_room)
 145            .add_request_handler(Server::call)
 146            .add_request_handler(Server::cancel_call)
 147            .add_message_handler(Server::decline_call)
 148            .add_request_handler(Server::update_participant_location)
 149            .add_request_handler(Server::share_project)
 150            .add_message_handler(Server::unshare_project)
 151            .add_request_handler(Server::join_project)
 152            .add_message_handler(Server::leave_project)
 153            .add_request_handler(Server::update_project)
 154            .add_request_handler(Server::update_worktree)
 155            .add_message_handler(Server::start_language_server)
 156            .add_message_handler(Server::update_language_server)
 157            .add_request_handler(Server::update_diagnostic_summary)
 158            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 159            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 160            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
 161            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 162            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 163            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 164            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 165            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 166            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 167            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 168            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 169            .add_request_handler(
 170                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 171            )
 172            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 173            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 174            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 175            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 176            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 177            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 178            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 179            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 180            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 181            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 182            .add_message_handler(Server::create_buffer_for_peer)
 183            .add_request_handler(Server::update_buffer)
 184            .add_message_handler(Server::update_buffer_file)
 185            .add_message_handler(Server::buffer_reloaded)
 186            .add_message_handler(Server::buffer_saved)
 187            .add_request_handler(Server::save_buffer)
 188            .add_request_handler(Server::get_users)
 189            .add_request_handler(Server::fuzzy_search_users)
 190            .add_request_handler(Server::request_contact)
 191            .add_request_handler(Server::remove_contact)
 192            .add_request_handler(Server::respond_to_contact_request)
 193            .add_request_handler(Server::follow)
 194            .add_message_handler(Server::unfollow)
 195            .add_message_handler(Server::update_followers)
 196            .add_message_handler(Server::update_diff_base)
 197            .add_request_handler(Server::get_private_user_info);
 198
 199        Arc::new(server)
 200    }
 201
 202    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 203    where
 204        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 205        Fut: 'static + Send + Future<Output = Result<()>>,
 206        M: EnvelopedMessage,
 207    {
 208        let prev_handler = self.handlers.insert(
 209            TypeId::of::<M>(),
 210            Box::new(move |server, sender_user_id, envelope| {
 211                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 212                let span = info_span!(
 213                    "handle message",
 214                    payload_type = envelope.payload_type_name()
 215                );
 216                span.in_scope(|| {
 217                    tracing::info!(
 218                        payload_type = envelope.payload_type_name(),
 219                        "message received"
 220                    );
 221                });
 222                let future = (handler)(server, sender_user_id, *envelope);
 223                async move {
 224                    if let Err(error) = future.await {
 225                        tracing::error!(%error, "error handling message");
 226                    }
 227                }
 228                .instrument(span)
 229                .boxed()
 230            }),
 231        );
 232        if prev_handler.is_some() {
 233            panic!("registered a handler for the same message twice");
 234        }
 235        self
 236    }
 237
 238    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 239    where
 240        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 241        Fut: 'static + Send + Future<Output = Result<()>>,
 242        M: EnvelopedMessage,
 243    {
 244        self.add_handler(move |server, sender_user_id, envelope| {
 245            handler(
 246                server,
 247                Message {
 248                    sender_user_id,
 249                    sender_connection_id: envelope.sender_id,
 250                    payload: envelope.payload,
 251                },
 252            )
 253        });
 254        self
 255    }
 256
 257    /// Handle a request while holding a lock to the store. This is useful when we're registering
 258    /// a connection but we want to respond on the connection before anybody else can send on it.
 259    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 260    where
 261        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 262        Fut: Send + Future<Output = Result<()>>,
 263        M: RequestMessage,
 264    {
 265        let handler = Arc::new(handler);
 266        self.add_handler(move |server, sender_user_id, envelope| {
 267            let receipt = envelope.receipt();
 268            let handler = handler.clone();
 269            async move {
 270                let request = Message {
 271                    sender_user_id,
 272                    sender_connection_id: envelope.sender_id,
 273                    payload: envelope.payload,
 274                };
 275                let responded = Arc::new(AtomicBool::default());
 276                let response = Response {
 277                    server: server.clone(),
 278                    responded: responded.clone(),
 279                    receipt,
 280                };
 281                match (handler)(server.clone(), request, response).await {
 282                    Ok(()) => {
 283                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 284                            Ok(())
 285                        } else {
 286                            Err(anyhow!("handler did not send a response"))?
 287                        }
 288                    }
 289                    Err(error) => {
 290                        server.peer.respond_with_error(
 291                            receipt,
 292                            proto::Error {
 293                                message: error.to_string(),
 294                            },
 295                        )?;
 296                        Err(error)
 297                    }
 298                }
 299            }
 300        })
 301    }
 302
    /// Returns a future that drives a single client connection from handshake to sign-out.
    ///
    /// * `connection` - transport registered with the peer.
    /// * `address` - used only in tracing output.
    /// * `user` - the user this connection belongs to.
    /// * `send_connection_id` - if present, receives the connection id assigned by the peer.
    /// * `executor` - provides the peer's timers and runs background message handlers.
    ///
    /// The future returns an error if any setup step fails before the message
    /// loop starts. Once the loop ends, `sign_out` is called; a failure there is
    /// logged and the future still resolves to `Ok(())`.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // Mutable because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // Register the connection with the peer; timers are created through the executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Report the assigned id to the caller; a dropped receiver is ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // First connection ever for this user: ask the client to show contacts
            // and persist the flag.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Both database lookups run concurrently; either failure aborts the connection.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // The store lock is confined to this block so it is released before the next await.
            {
                let mut store = this.store().await;
                store.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }

            // Deliver a call that is already pending for this user, if any.
            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written: connection
                // I/O first, then in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                // Exit the span before attaching it to the future via `instrument`.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that have not finished are dropped before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id, user_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 424
 425    #[instrument(skip(self), err)]
 426    async fn sign_out(
 427        self: &mut Arc<Self>,
 428        connection_id: ConnectionId,
 429        user_id: UserId,
 430    ) -> Result<()> {
 431        self.peer.disconnect(connection_id);
 432        let decline_calls = {
 433            let mut store = self.store().await;
 434            store.remove_connection(connection_id)?;
 435            let mut connections = store.connection_ids_for_user(user_id);
 436            connections.next().is_none()
 437        };
 438
 439        self.leave_room_for_connection(connection_id, user_id)
 440            .await
 441            .trace_err();
 442        if decline_calls {
 443            if let Some(room) = self
 444                .app_state
 445                .db
 446                .decline_call(None, user_id)
 447                .await
 448                .trace_err()
 449            {
 450                self.room_updated(&room);
 451            }
 452        }
 453
 454        self.update_user_contacts(user_id).await?;
 455
 456        Ok(())
 457    }
 458
 459    pub async fn invite_code_redeemed(
 460        self: &Arc<Self>,
 461        inviter_id: UserId,
 462        invitee_id: UserId,
 463    ) -> Result<()> {
 464        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 465            if let Some(code) = &user.invite_code {
 466                let store = self.store().await;
 467                let invitee_contact = store.contact_for_user(invitee_id, true, false);
 468                for connection_id in store.connection_ids_for_user(inviter_id) {
 469                    self.peer.send(
 470                        connection_id,
 471                        proto::UpdateContacts {
 472                            contacts: vec![invitee_contact.clone()],
 473                            ..Default::default()
 474                        },
 475                    )?;
 476                    self.peer.send(
 477                        connection_id,
 478                        proto::UpdateInviteInfo {
 479                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 480                            count: user.invite_count as u32,
 481                        },
 482                    )?;
 483                }
 484            }
 485        }
 486        Ok(())
 487    }
 488
 489    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 490        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 491            if let Some(invite_code) = &user.invite_code {
 492                let store = self.store().await;
 493                for connection_id in store.connection_ids_for_user(user_id) {
 494                    self.peer.send(
 495                        connection_id,
 496                        proto::UpdateInviteInfo {
 497                            url: format!(
 498                                "{}{}",
 499                                self.app_state.config.invite_link_prefix, invite_code
 500                            ),
 501                            count: user.invite_count as u32,
 502                        },
 503                    )?;
 504                }
 505            }
 506        }
 507        Ok(())
 508    }
 509
 510    async fn ping(
 511        self: Arc<Server>,
 512        _: Message<proto::Ping>,
 513        response: Response<proto::Ping>,
 514    ) -> Result<()> {
 515        response.send(proto::Ack {})?;
 516        Ok(())
 517    }
 518
 519    async fn create_room(
 520        self: Arc<Server>,
 521        request: Message<proto::CreateRoom>,
 522        response: Response<proto::CreateRoom>,
 523    ) -> Result<()> {
 524        let room = self
 525            .app_state
 526            .db
 527            .create_room(request.sender_user_id, request.sender_connection_id)
 528            .await?;
 529
 530        let live_kit_connection_info =
 531            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 532                if let Some(_) = live_kit
 533                    .create_room(room.live_kit_room.clone())
 534                    .await
 535                    .trace_err()
 536                {
 537                    if let Some(token) = live_kit
 538                        .room_token(
 539                            &room.live_kit_room,
 540                            &request.sender_connection_id.to_string(),
 541                        )
 542                        .trace_err()
 543                    {
 544                        Some(proto::LiveKitConnectionInfo {
 545                            server_url: live_kit.url().into(),
 546                            token,
 547                        })
 548                    } else {
 549                        None
 550                    }
 551                } else {
 552                    None
 553                }
 554            } else {
 555                None
 556            };
 557
 558        response.send(proto::CreateRoomResponse {
 559            room: Some(room),
 560            live_kit_connection_info,
 561        })?;
 562        self.update_user_contacts(request.sender_user_id).await?;
 563        Ok(())
 564    }
 565
 566    async fn join_room(
 567        self: Arc<Server>,
 568        request: Message<proto::JoinRoom>,
 569        response: Response<proto::JoinRoom>,
 570    ) -> Result<()> {
 571        let room = self
 572            .app_state
 573            .db
 574            .join_room(
 575                RoomId::from_proto(request.payload.id),
 576                request.sender_user_id,
 577                request.sender_connection_id,
 578            )
 579            .await?;
 580        for connection_id in self
 581            .store()
 582            .await
 583            .connection_ids_for_user(request.sender_user_id)
 584        {
 585            self.peer
 586                .send(connection_id, proto::CallCanceled {})
 587                .trace_err();
 588        }
 589
 590        let live_kit_connection_info =
 591            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 592                if let Some(token) = live_kit
 593                    .room_token(
 594                        &room.live_kit_room,
 595                        &request.sender_connection_id.to_string(),
 596                    )
 597                    .trace_err()
 598                {
 599                    Some(proto::LiveKitConnectionInfo {
 600                        server_url: live_kit.url().into(),
 601                        token,
 602                    })
 603                } else {
 604                    None
 605                }
 606            } else {
 607                None
 608            };
 609
 610        self.room_updated(&room);
 611        response.send(proto::JoinRoomResponse {
 612            room: Some(room),
 613            live_kit_connection_info,
 614        })?;
 615
 616        self.update_user_contacts(request.sender_user_id).await?;
 617        Ok(())
 618    }
 619
 620    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 621        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
 622            .await
 623    }
 624
 625    async fn leave_room_for_connection(
 626        self: &Arc<Server>,
 627        leaving_connection_id: ConnectionId,
 628        leaving_user_id: UserId,
 629    ) -> Result<()> {
 630        let mut contacts_to_update = HashSet::default();
 631
 632        let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else {
 633            return Err(anyhow!("no room to leave"))?;
 634        };
 635        contacts_to_update.insert(leaving_user_id);
 636
 637        for project in left_room.left_projects.into_values() {
 638            for connection_id in project.connection_ids {
 639                if project.host_user_id == leaving_user_id {
 640                    self.peer
 641                        .send(
 642                            connection_id,
 643                            proto::UnshareProject {
 644                                project_id: project.id.to_proto(),
 645                            },
 646                        )
 647                        .trace_err();
 648                } else {
 649                    self.peer
 650                        .send(
 651                            connection_id,
 652                            proto::RemoveProjectCollaborator {
 653                                project_id: project.id.to_proto(),
 654                                peer_id: leaving_connection_id.0,
 655                            },
 656                        )
 657                        .trace_err();
 658                }
 659            }
 660
 661            self.peer
 662                .send(
 663                    leaving_connection_id,
 664                    proto::UnshareProject {
 665                        project_id: project.id.to_proto(),
 666                    },
 667                )
 668                .trace_err();
 669        }
 670
 671        self.room_updated(&left_room.room);
 672        {
 673            let store = self.store().await;
 674            for canceled_user_id in left_room.canceled_calls_to_user_ids {
 675                for connection_id in store.connection_ids_for_user(canceled_user_id) {
 676                    self.peer
 677                        .send(connection_id, proto::CallCanceled {})
 678                        .trace_err();
 679                }
 680                contacts_to_update.insert(canceled_user_id);
 681            }
 682        }
 683
 684        for contact_user_id in contacts_to_update {
 685            self.update_user_contacts(contact_user_id).await?;
 686        }
 687
 688        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 689            live_kit
 690                .remove_participant(
 691                    left_room.room.live_kit_room.clone(),
 692                    leaving_connection_id.to_string(),
 693                )
 694                .await
 695                .trace_err();
 696
 697            if left_room.room.participants.is_empty() {
 698                live_kit
 699                    .delete_room(left_room.room.live_kit_room)
 700                    .await
 701                    .trace_err();
 702            }
 703        }
 704
 705        Ok(())
 706    }
 707
 708    async fn call(
 709        self: Arc<Server>,
 710        request: Message<proto::Call>,
 711        response: Response<proto::Call>,
 712    ) -> Result<()> {
 713        let room_id = RoomId::from_proto(request.payload.room_id);
 714        let calling_user_id = request.sender_user_id;
 715        let calling_connection_id = request.sender_connection_id;
 716        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 717        let initial_project_id = request
 718            .payload
 719            .initial_project_id
 720            .map(ProjectId::from_proto);
 721        if !self
 722            .app_state
 723            .db
 724            .has_contact(calling_user_id, called_user_id)
 725            .await?
 726        {
 727            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 728        }
 729
 730        let (room, incoming_call) = self
 731            .app_state
 732            .db
 733            .call(
 734                room_id,
 735                calling_user_id,
 736                calling_connection_id,
 737                called_user_id,
 738                initial_project_id,
 739            )
 740            .await?;
 741        self.room_updated(&room);
 742        self.update_user_contacts(called_user_id).await?;
 743
 744        let mut calls = self
 745            .store()
 746            .await
 747            .connection_ids_for_user(called_user_id)
 748            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 749            .collect::<FuturesUnordered<_>>();
 750
 751        while let Some(call_response) = calls.next().await {
 752            match call_response.as_ref() {
 753                Ok(_) => {
 754                    response.send(proto::Ack {})?;
 755                    return Ok(());
 756                }
 757                Err(_) => {
 758                    call_response.trace_err();
 759                }
 760            }
 761        }
 762
 763        let room = self
 764            .app_state
 765            .db
 766            .call_failed(room_id, called_user_id)
 767            .await?;
 768        self.room_updated(&room);
 769        self.update_user_contacts(called_user_id).await?;
 770
 771        Err(anyhow!("failed to ring user"))?
 772    }
 773
 774    async fn cancel_call(
 775        self: Arc<Server>,
 776        request: Message<proto::CancelCall>,
 777        response: Response<proto::CancelCall>,
 778    ) -> Result<()> {
 779        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 780        let room_id = RoomId::from_proto(request.payload.room_id);
 781        let room = self
 782            .app_state
 783            .db
 784            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
 785            .await?;
 786        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
 787            self.peer
 788                .send(connection_id, proto::CallCanceled {})
 789                .trace_err();
 790        }
 791        self.room_updated(&room);
 792        response.send(proto::Ack {})?;
 793
 794        self.update_user_contacts(called_user_id).await?;
 795        Ok(())
 796    }
 797
 798    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 799        let room_id = RoomId::from_proto(message.payload.room_id);
 800        let room = self
 801            .app_state
 802            .db
 803            .decline_call(Some(room_id), message.sender_user_id)
 804            .await?;
 805        for connection_id in self
 806            .store()
 807            .await
 808            .connection_ids_for_user(message.sender_user_id)
 809        {
 810            self.peer
 811                .send(connection_id, proto::CallCanceled {})
 812                .trace_err();
 813        }
 814        self.room_updated(&room);
 815        self.update_user_contacts(message.sender_user_id).await?;
 816        Ok(())
 817    }
 818
 819    async fn update_participant_location(
 820        self: Arc<Server>,
 821        request: Message<proto::UpdateParticipantLocation>,
 822        response: Response<proto::UpdateParticipantLocation>,
 823    ) -> Result<()> {
 824        let room_id = RoomId::from_proto(request.payload.room_id);
 825        let location = request
 826            .payload
 827            .location
 828            .ok_or_else(|| anyhow!("invalid location"))?;
 829        let room = self
 830            .app_state
 831            .db
 832            .update_room_participant_location(room_id, request.sender_connection_id, location)
 833            .await?;
 834        self.room_updated(&room);
 835        response.send(proto::Ack {})?;
 836        Ok(())
 837    }
 838
 839    fn room_updated(&self, room: &proto::Room) {
 840        for participant in &room.participants {
 841            self.peer
 842                .send(
 843                    ConnectionId(participant.peer_id),
 844                    proto::RoomUpdated {
 845                        room: Some(room.clone()),
 846                    },
 847                )
 848                .trace_err();
 849        }
 850    }
 851
 852    async fn share_project(
 853        self: Arc<Server>,
 854        request: Message<proto::ShareProject>,
 855        response: Response<proto::ShareProject>,
 856    ) -> Result<()> {
 857        let (project_id, room) = self
 858            .app_state
 859            .db
 860            .share_project(
 861                RoomId::from_proto(request.payload.room_id),
 862                request.sender_connection_id,
 863                &request.payload.worktrees,
 864            )
 865            .await
 866            .unwrap();
 867        response.send(proto::ShareProjectResponse {
 868            project_id: project_id.to_proto(),
 869        })?;
 870        self.room_updated(&room);
 871
 872        Ok(())
 873    }
 874
 875    async fn unshare_project(
 876        self: Arc<Server>,
 877        message: Message<proto::UnshareProject>,
 878    ) -> Result<()> {
 879        let project_id = ProjectId::from_proto(message.payload.project_id);
 880
 881        let (room, guest_connection_ids) = self
 882            .app_state
 883            .db
 884            .unshare_project(project_id, message.sender_connection_id)
 885            .await?;
 886
 887        broadcast(
 888            message.sender_connection_id,
 889            guest_connection_ids,
 890            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 891        );
 892        self.room_updated(&room);
 893
 894        Ok(())
 895    }
 896
 897    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 898        let contacts = self.app_state.db.get_contacts(user_id).await?;
 899        let busy = self.app_state.db.is_user_busy(user_id).await?;
 900        let store = self.store().await;
 901        let updated_contact = store.contact_for_user(user_id, false, busy);
 902        for contact in contacts {
 903            if let db::Contact::Accepted {
 904                user_id: contact_user_id,
 905                ..
 906            } = contact
 907            {
 908                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 909                    self.peer
 910                        .send(
 911                            contact_conn_id,
 912                            proto::UpdateContacts {
 913                                contacts: vec![updated_contact.clone()],
 914                                remove_contacts: Default::default(),
 915                                incoming_requests: Default::default(),
 916                                remove_incoming_requests: Default::default(),
 917                                outgoing_requests: Default::default(),
 918                                remove_outgoing_requests: Default::default(),
 919                            },
 920                        )
 921                        .trace_err();
 922                }
 923            }
 924        }
 925        Ok(())
 926    }
 927
 928    async fn join_project(
 929        self: Arc<Server>,
 930        request: Message<proto::JoinProject>,
 931        response: Response<proto::JoinProject>,
 932    ) -> Result<()> {
 933        let project_id = ProjectId::from_proto(request.payload.project_id);
 934        let guest_user_id = request.sender_user_id;
 935
 936        tracing::info!(%project_id, "join project");
 937
 938        let (project, replica_id) = self
 939            .app_state
 940            .db
 941            .join_project(project_id, request.sender_connection_id)
 942            .await?;
 943
 944        let collaborators = project
 945            .collaborators
 946            .iter()
 947            .filter(|collaborator| {
 948                collaborator.connection_id != request.sender_connection_id.0 as i32
 949            })
 950            .map(|collaborator| proto::Collaborator {
 951                peer_id: collaborator.connection_id as u32,
 952                replica_id: collaborator.replica_id.0 as u32,
 953                user_id: collaborator.user_id.to_proto(),
 954            })
 955            .collect::<Vec<_>>();
 956        let worktrees = project
 957            .worktrees
 958            .iter()
 959            .map(|(id, worktree)| proto::WorktreeMetadata {
 960                id: id.to_proto(),
 961                root_name: worktree.root_name.clone(),
 962                visible: worktree.visible,
 963                abs_path: worktree.abs_path.clone(),
 964            })
 965            .collect::<Vec<_>>();
 966
 967        for collaborator in &collaborators {
 968            self.peer
 969                .send(
 970                    ConnectionId(collaborator.peer_id),
 971                    proto::AddProjectCollaborator {
 972                        project_id: project_id.to_proto(),
 973                        collaborator: Some(proto::Collaborator {
 974                            peer_id: request.sender_connection_id.0,
 975                            replica_id: replica_id.0 as u32,
 976                            user_id: guest_user_id.to_proto(),
 977                        }),
 978                    },
 979                )
 980                .trace_err();
 981        }
 982
 983        // First, we send the metadata associated with each worktree.
 984        response.send(proto::JoinProjectResponse {
 985            worktrees: worktrees.clone(),
 986            replica_id: replica_id.0 as u32,
 987            collaborators: collaborators.clone(),
 988            language_servers: project.language_servers.clone(),
 989        })?;
 990
 991        for (worktree_id, worktree) in project.worktrees {
 992            #[cfg(any(test, feature = "test-support"))]
 993            const MAX_CHUNK_SIZE: usize = 2;
 994            #[cfg(not(any(test, feature = "test-support")))]
 995            const MAX_CHUNK_SIZE: usize = 256;
 996
 997            // Stream this worktree's entries.
 998            let message = proto::UpdateWorktree {
 999                project_id: project_id.to_proto(),
1000                worktree_id: worktree_id.to_proto(),
1001                abs_path: worktree.abs_path.clone(),
1002                root_name: worktree.root_name,
1003                updated_entries: worktree.entries,
1004                removed_entries: Default::default(),
1005                scan_id: worktree.scan_id,
1006                is_last_update: worktree.is_complete,
1007            };
1008            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1009                self.peer
1010                    .send(request.sender_connection_id, update.clone())?;
1011            }
1012
1013            // Stream this worktree's diagnostics.
1014            for summary in worktree.diagnostic_summaries {
1015                self.peer.send(
1016                    request.sender_connection_id,
1017                    proto::UpdateDiagnosticSummary {
1018                        project_id: project_id.to_proto(),
1019                        worktree_id: worktree.id.to_proto(),
1020                        summary: Some(summary),
1021                    },
1022                )?;
1023            }
1024        }
1025
1026        for language_server in &project.language_servers {
1027            self.peer.send(
1028                request.sender_connection_id,
1029                proto::UpdateLanguageServer {
1030                    project_id: project_id.to_proto(),
1031                    language_server_id: language_server.id,
1032                    variant: Some(
1033                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1034                            proto::LspDiskBasedDiagnosticsUpdated {},
1035                        ),
1036                    ),
1037                },
1038            )?;
1039        }
1040
1041        Ok(())
1042    }
1043
1044    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1045        let sender_id = request.sender_connection_id;
1046        let project_id = ProjectId::from_proto(request.payload.project_id);
1047        let project;
1048        {
1049            project = self
1050                .app_state
1051                .db
1052                .leave_project(project_id, sender_id)
1053                .await?;
1054            tracing::info!(
1055                %project_id,
1056                host_user_id = %project.host_user_id,
1057                host_connection_id = %project.host_connection_id,
1058                "leave project"
1059            );
1060
1061            broadcast(sender_id, project.connection_ids, |conn_id| {
1062                self.peer.send(
1063                    conn_id,
1064                    proto::RemoveProjectCollaborator {
1065                        project_id: project_id.to_proto(),
1066                        peer_id: sender_id.0,
1067                    },
1068                )
1069            });
1070        }
1071
1072        Ok(())
1073    }
1074
1075    async fn update_project(
1076        self: Arc<Server>,
1077        request: Message<proto::UpdateProject>,
1078        response: Response<proto::UpdateProject>,
1079    ) -> Result<()> {
1080        let project_id = ProjectId::from_proto(request.payload.project_id);
1081        let (room, guest_connection_ids) = self
1082            .app_state
1083            .db
1084            .update_project(
1085                project_id,
1086                request.sender_connection_id,
1087                &request.payload.worktrees,
1088            )
1089            .await?;
1090        broadcast(
1091            request.sender_connection_id,
1092            guest_connection_ids,
1093            |connection_id| {
1094                self.peer.forward_send(
1095                    request.sender_connection_id,
1096                    connection_id,
1097                    request.payload.clone(),
1098                )
1099            },
1100        );
1101        self.room_updated(&room);
1102        response.send(proto::Ack {})?;
1103
1104        Ok(())
1105    }
1106
1107    async fn update_worktree(
1108        self: Arc<Server>,
1109        request: Message<proto::UpdateWorktree>,
1110        response: Response<proto::UpdateWorktree>,
1111    ) -> Result<()> {
1112        let guest_connection_ids = self
1113            .app_state
1114            .db
1115            .update_worktree(&request.payload, request.sender_connection_id)
1116            .await?;
1117
1118        broadcast(
1119            request.sender_connection_id,
1120            guest_connection_ids,
1121            |connection_id| {
1122                self.peer.forward_send(
1123                    request.sender_connection_id,
1124                    connection_id,
1125                    request.payload.clone(),
1126                )
1127            },
1128        );
1129        response.send(proto::Ack {})?;
1130        Ok(())
1131    }
1132
1133    async fn update_diagnostic_summary(
1134        self: Arc<Server>,
1135        request: Message<proto::UpdateDiagnosticSummary>,
1136        response: Response<proto::UpdateDiagnosticSummary>,
1137    ) -> Result<()> {
1138        let guest_connection_ids = self
1139            .app_state
1140            .db
1141            .update_diagnostic_summary(&request.payload, request.sender_connection_id)
1142            .await?;
1143
1144        broadcast(
1145            request.sender_connection_id,
1146            guest_connection_ids,
1147            |connection_id| {
1148                self.peer.forward_send(
1149                    request.sender_connection_id,
1150                    connection_id,
1151                    request.payload.clone(),
1152                )
1153            },
1154        );
1155
1156        response.send(proto::Ack {})?;
1157        Ok(())
1158    }
1159
1160    async fn start_language_server(
1161        self: Arc<Server>,
1162        request: Message<proto::StartLanguageServer>,
1163    ) -> Result<()> {
1164        let guest_connection_ids = self
1165            .app_state
1166            .db
1167            .start_language_server(&request.payload, request.sender_connection_id)
1168            .await?;
1169
1170        broadcast(
1171            request.sender_connection_id,
1172            guest_connection_ids,
1173            |connection_id| {
1174                self.peer.forward_send(
1175                    request.sender_connection_id,
1176                    connection_id,
1177                    request.payload.clone(),
1178                )
1179            },
1180        );
1181        Ok(())
1182    }
1183
1184    async fn update_language_server(
1185        self: Arc<Server>,
1186        request: Message<proto::UpdateLanguageServer>,
1187    ) -> Result<()> {
1188        let project_id = ProjectId::from_proto(request.payload.project_id);
1189        let project_connection_ids = self
1190            .app_state
1191            .db
1192            .project_connection_ids(project_id, request.sender_connection_id)
1193            .await?;
1194        broadcast(
1195            request.sender_connection_id,
1196            project_connection_ids,
1197            |connection_id| {
1198                self.peer.forward_send(
1199                    request.sender_connection_id,
1200                    connection_id,
1201                    request.payload.clone(),
1202                )
1203            },
1204        );
1205        Ok(())
1206    }
1207
1208    async fn forward_project_request<T>(
1209        self: Arc<Server>,
1210        request: Message<T>,
1211        response: Response<T>,
1212    ) -> Result<()>
1213    where
1214        T: EntityMessage + RequestMessage,
1215    {
1216        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1217        let collaborators = self
1218            .app_state
1219            .db
1220            .project_collaborators(project_id, request.sender_connection_id)
1221            .await?;
1222        let host = collaborators
1223            .iter()
1224            .find(|collaborator| collaborator.is_host)
1225            .ok_or_else(|| anyhow!("host not found"))?;
1226
1227        let payload = self
1228            .peer
1229            .forward_request(
1230                request.sender_connection_id,
1231                ConnectionId(host.connection_id as u32),
1232                request.payload,
1233            )
1234            .await?;
1235
1236        response.send(payload)?;
1237        Ok(())
1238    }
1239
1240    async fn save_buffer(
1241        self: Arc<Server>,
1242        request: Message<proto::SaveBuffer>,
1243        response: Response<proto::SaveBuffer>,
1244    ) -> Result<()> {
1245        let project_id = ProjectId::from_proto(request.payload.project_id);
1246        let collaborators = self
1247            .app_state
1248            .db
1249            .project_collaborators(project_id, request.sender_connection_id)
1250            .await?;
1251        let host = collaborators
1252            .into_iter()
1253            .find(|collaborator| collaborator.is_host)
1254            .ok_or_else(|| anyhow!("host not found"))?;
1255        let host_connection_id = ConnectionId(host.connection_id as u32);
1256        let response_payload = self
1257            .peer
1258            .forward_request(
1259                request.sender_connection_id,
1260                host_connection_id,
1261                request.payload.clone(),
1262            )
1263            .await?;
1264
1265        let mut collaborators = self
1266            .app_state
1267            .db
1268            .project_collaborators(project_id, request.sender_connection_id)
1269            .await?;
1270        collaborators.retain(|collaborator| {
1271            collaborator.connection_id != request.sender_connection_id.0 as i32
1272        });
1273        let project_connection_ids = collaborators
1274            .into_iter()
1275            .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1276        broadcast(host_connection_id, project_connection_ids, |conn_id| {
1277            self.peer
1278                .forward_send(host_connection_id, conn_id, response_payload.clone())
1279        });
1280        response.send(response_payload)?;
1281        Ok(())
1282    }
1283
1284    async fn create_buffer_for_peer(
1285        self: Arc<Server>,
1286        request: Message<proto::CreateBufferForPeer>,
1287    ) -> Result<()> {
1288        self.peer.forward_send(
1289            request.sender_connection_id,
1290            ConnectionId(request.payload.peer_id),
1291            request.payload,
1292        )?;
1293        Ok(())
1294    }
1295
1296    async fn update_buffer(
1297        self: Arc<Server>,
1298        request: Message<proto::UpdateBuffer>,
1299        response: Response<proto::UpdateBuffer>,
1300    ) -> Result<()> {
1301        let project_id = ProjectId::from_proto(request.payload.project_id);
1302        let project_connection_ids = self
1303            .app_state
1304            .db
1305            .project_connection_ids(project_id, request.sender_connection_id)
1306            .await?;
1307
1308        broadcast(
1309            request.sender_connection_id,
1310            project_connection_ids,
1311            |connection_id| {
1312                self.peer.forward_send(
1313                    request.sender_connection_id,
1314                    connection_id,
1315                    request.payload.clone(),
1316                )
1317            },
1318        );
1319        response.send(proto::Ack {})?;
1320        Ok(())
1321    }
1322
1323    async fn update_buffer_file(
1324        self: Arc<Server>,
1325        request: Message<proto::UpdateBufferFile>,
1326    ) -> Result<()> {
1327        let project_id = ProjectId::from_proto(request.payload.project_id);
1328        let project_connection_ids = self
1329            .app_state
1330            .db
1331            .project_connection_ids(project_id, request.sender_connection_id)
1332            .await?;
1333
1334        broadcast(
1335            request.sender_connection_id,
1336            project_connection_ids,
1337            |connection_id| {
1338                self.peer.forward_send(
1339                    request.sender_connection_id,
1340                    connection_id,
1341                    request.payload.clone(),
1342                )
1343            },
1344        );
1345        Ok(())
1346    }
1347
1348    async fn buffer_reloaded(
1349        self: Arc<Server>,
1350        request: Message<proto::BufferReloaded>,
1351    ) -> Result<()> {
1352        let project_id = ProjectId::from_proto(request.payload.project_id);
1353        let project_connection_ids = self
1354            .app_state
1355            .db
1356            .project_connection_ids(project_id, request.sender_connection_id)
1357            .await?;
1358        broadcast(
1359            request.sender_connection_id,
1360            project_connection_ids,
1361            |connection_id| {
1362                self.peer.forward_send(
1363                    request.sender_connection_id,
1364                    connection_id,
1365                    request.payload.clone(),
1366                )
1367            },
1368        );
1369        Ok(())
1370    }
1371
1372    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1373        let project_id = ProjectId::from_proto(request.payload.project_id);
1374        let project_connection_ids = self
1375            .app_state
1376            .db
1377            .project_connection_ids(project_id, request.sender_connection_id)
1378            .await?;
1379        broadcast(
1380            request.sender_connection_id,
1381            project_connection_ids,
1382            |connection_id| {
1383                self.peer.forward_send(
1384                    request.sender_connection_id,
1385                    connection_id,
1386                    request.payload.clone(),
1387                )
1388            },
1389        );
1390        Ok(())
1391    }
1392
1393    async fn follow(
1394        self: Arc<Self>,
1395        request: Message<proto::Follow>,
1396        response: Response<proto::Follow>,
1397    ) -> Result<()> {
1398        let project_id = ProjectId::from_proto(request.payload.project_id);
1399        let leader_id = ConnectionId(request.payload.leader_id);
1400        let follower_id = request.sender_connection_id;
1401        let project_connection_ids = self
1402            .app_state
1403            .db
1404            .project_connection_ids(project_id, request.sender_connection_id)
1405            .await?;
1406
1407        if !project_connection_ids.contains(&leader_id) {
1408            Err(anyhow!("no such peer"))?;
1409        }
1410
1411        let mut response_payload = self
1412            .peer
1413            .forward_request(request.sender_connection_id, leader_id, request.payload)
1414            .await?;
1415        response_payload
1416            .views
1417            .retain(|view| view.leader_id != Some(follower_id.0));
1418        response.send(response_payload)?;
1419        Ok(())
1420    }
1421
1422    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1423        let project_id = ProjectId::from_proto(request.payload.project_id);
1424        let leader_id = ConnectionId(request.payload.leader_id);
1425        let project_connection_ids = self
1426            .app_state
1427            .db
1428            .project_connection_ids(project_id, request.sender_connection_id)
1429            .await?;
1430        if !project_connection_ids.contains(&leader_id) {
1431            Err(anyhow!("no such peer"))?;
1432        }
1433        self.peer
1434            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1435        Ok(())
1436    }
1437
1438    async fn update_followers(
1439        self: Arc<Self>,
1440        request: Message<proto::UpdateFollowers>,
1441    ) -> Result<()> {
1442        let project_id = ProjectId::from_proto(request.payload.project_id);
1443        let project_connection_ids = self
1444            .app_state
1445            .db
1446            .project_connection_ids(project_id, request.sender_connection_id)
1447            .await?;
1448
1449        let leader_id = request
1450            .payload
1451            .variant
1452            .as_ref()
1453            .and_then(|variant| match variant {
1454                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1455                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1456                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1457            });
1458        for follower_id in &request.payload.follower_ids {
1459            let follower_id = ConnectionId(*follower_id);
1460            if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1461                self.peer.forward_send(
1462                    request.sender_connection_id,
1463                    follower_id,
1464                    request.payload.clone(),
1465                )?;
1466            }
1467        }
1468        Ok(())
1469    }
1470
1471    async fn get_users(
1472        self: Arc<Server>,
1473        request: Message<proto::GetUsers>,
1474        response: Response<proto::GetUsers>,
1475    ) -> Result<()> {
1476        let user_ids = request
1477            .payload
1478            .user_ids
1479            .into_iter()
1480            .map(UserId::from_proto)
1481            .collect();
1482        let users = self
1483            .app_state
1484            .db
1485            .get_users_by_ids(user_ids)
1486            .await?
1487            .into_iter()
1488            .map(|user| proto::User {
1489                id: user.id.to_proto(),
1490                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1491                github_login: user.github_login,
1492            })
1493            .collect();
1494        response.send(proto::UsersResponse { users })?;
1495        Ok(())
1496    }
1497
1498    async fn fuzzy_search_users(
1499        self: Arc<Server>,
1500        request: Message<proto::FuzzySearchUsers>,
1501        response: Response<proto::FuzzySearchUsers>,
1502    ) -> Result<()> {
1503        let query = request.payload.query;
1504        let db = &self.app_state.db;
1505        let users = match query.len() {
1506            0 => vec![],
1507            1 | 2 => db
1508                .get_user_by_github_account(&query, None)
1509                .await?
1510                .into_iter()
1511                .collect(),
1512            _ => db.fuzzy_search_users(&query, 10).await?,
1513        };
1514        let users = users
1515            .into_iter()
1516            .filter(|user| user.id != request.sender_user_id)
1517            .map(|user| proto::User {
1518                id: user.id.to_proto(),
1519                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1520                github_login: user.github_login,
1521            })
1522            .collect();
1523        response.send(proto::UsersResponse { users })?;
1524        Ok(())
1525    }
1526
1527    async fn request_contact(
1528        self: Arc<Server>,
1529        request: Message<proto::RequestContact>,
1530        response: Response<proto::RequestContact>,
1531    ) -> Result<()> {
1532        let requester_id = request.sender_user_id;
1533        let responder_id = UserId::from_proto(request.payload.responder_id);
1534        if requester_id == responder_id {
1535            return Err(anyhow!("cannot add yourself as a contact"))?;
1536        }
1537
1538        self.app_state
1539            .db
1540            .send_contact_request(requester_id, responder_id)
1541            .await?;
1542
1543        // Update outgoing contact requests of requester
1544        let mut update = proto::UpdateContacts::default();
1545        update.outgoing_requests.push(responder_id.to_proto());
1546        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1547            self.peer.send(connection_id, update.clone())?;
1548        }
1549
1550        // Update incoming contact requests of responder
1551        let mut update = proto::UpdateContacts::default();
1552        update
1553            .incoming_requests
1554            .push(proto::IncomingContactRequest {
1555                requester_id: requester_id.to_proto(),
1556                should_notify: true,
1557            });
1558        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1559            self.peer.send(connection_id, update.clone())?;
1560        }
1561
1562        response.send(proto::Ack {})?;
1563        Ok(())
1564    }
1565
1566    async fn respond_to_contact_request(
1567        self: Arc<Server>,
1568        request: Message<proto::RespondToContactRequest>,
1569        response: Response<proto::RespondToContactRequest>,
1570    ) -> Result<()> {
1571        let responder_id = request.sender_user_id;
1572        let requester_id = UserId::from_proto(request.payload.requester_id);
1573        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1574            self.app_state
1575                .db
1576                .dismiss_contact_notification(responder_id, requester_id)
1577                .await?;
1578        } else {
1579            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1580            self.app_state
1581                .db
1582                .respond_to_contact_request(responder_id, requester_id, accept)
1583                .await?;
1584            let busy = self.app_state.db.is_user_busy(requester_id).await?;
1585
1586            let store = self.store().await;
1587            // Update responder with new contact
1588            let mut update = proto::UpdateContacts::default();
1589            if accept {
1590                update
1591                    .contacts
1592                    .push(store.contact_for_user(requester_id, false, busy));
1593            }
1594            update
1595                .remove_incoming_requests
1596                .push(requester_id.to_proto());
1597            for connection_id in store.connection_ids_for_user(responder_id) {
1598                self.peer.send(connection_id, update.clone())?;
1599            }
1600
1601            // Update requester with new contact
1602            let mut update = proto::UpdateContacts::default();
1603            if accept {
1604                update
1605                    .contacts
1606                    .push(store.contact_for_user(responder_id, true, busy));
1607            }
1608            update
1609                .remove_outgoing_requests
1610                .push(responder_id.to_proto());
1611            for connection_id in store.connection_ids_for_user(requester_id) {
1612                self.peer.send(connection_id, update.clone())?;
1613            }
1614        }
1615
1616        response.send(proto::Ack {})?;
1617        Ok(())
1618    }
1619
1620    async fn remove_contact(
1621        self: Arc<Server>,
1622        request: Message<proto::RemoveContact>,
1623        response: Response<proto::RemoveContact>,
1624    ) -> Result<()> {
1625        let requester_id = request.sender_user_id;
1626        let responder_id = UserId::from_proto(request.payload.user_id);
1627        self.app_state
1628            .db
1629            .remove_contact(requester_id, responder_id)
1630            .await?;
1631
1632        // Update outgoing contact requests of requester
1633        let mut update = proto::UpdateContacts::default();
1634        update
1635            .remove_outgoing_requests
1636            .push(responder_id.to_proto());
1637        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1638            self.peer.send(connection_id, update.clone())?;
1639        }
1640
1641        // Update incoming contact requests of responder
1642        let mut update = proto::UpdateContacts::default();
1643        update
1644            .remove_incoming_requests
1645            .push(requester_id.to_proto());
1646        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1647            self.peer.send(connection_id, update.clone())?;
1648        }
1649
1650        response.send(proto::Ack {})?;
1651        Ok(())
1652    }
1653
1654    async fn update_diff_base(
1655        self: Arc<Server>,
1656        request: Message<proto::UpdateDiffBase>,
1657    ) -> Result<()> {
1658        let project_id = ProjectId::from_proto(request.payload.project_id);
1659        let project_connection_ids = self
1660            .app_state
1661            .db
1662            .project_connection_ids(project_id, request.sender_connection_id)
1663            .await?;
1664        broadcast(
1665            request.sender_connection_id,
1666            project_connection_ids,
1667            |connection_id| {
1668                self.peer.forward_send(
1669                    request.sender_connection_id,
1670                    connection_id,
1671                    request.payload.clone(),
1672                )
1673            },
1674        );
1675        Ok(())
1676    }
1677
1678    async fn get_private_user_info(
1679        self: Arc<Self>,
1680        request: Message<proto::GetPrivateUserInfo>,
1681        response: Response<proto::GetPrivateUserInfo>,
1682    ) -> Result<()> {
1683        let metrics_id = self
1684            .app_state
1685            .db
1686            .get_user_metrics_id(request.sender_user_id)
1687            .await?;
1688        let user = self
1689            .app_state
1690            .db
1691            .get_user_by_id(request.sender_user_id)
1692            .await?
1693            .ok_or_else(|| anyhow!("user not found"))?;
1694        response.send(proto::GetPrivateUserInfoResponse {
1695            metrics_id,
1696            staff: user.admin,
1697        })?;
1698        Ok(())
1699    }
1700
1701    pub(crate) async fn store(&self) -> StoreGuard<'_> {
1702        #[cfg(test)]
1703        tokio::task::yield_now().await;
1704        let guard = self.store.lock().await;
1705        #[cfg(test)]
1706        tokio::task::yield_now().await;
1707        StoreGuard {
1708            guard,
1709            _not_send: PhantomData,
1710        }
1711    }
1712
1713    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1714        ServerSnapshot {
1715            store: self.store().await,
1716            peer: &self.peer,
1717        }
1718    }
1719}
1720
1721impl<'a> Deref for StoreGuard<'a> {
1722    type Target = Store;
1723
1724    fn deref(&self) -> &Self::Target {
1725        &*self.guard
1726    }
1727}
1728
1729impl<'a> DerefMut for StoreGuard<'a> {
1730    fn deref_mut(&mut self) -> &mut Self::Target {
1731        &mut *self.guard
1732    }
1733}
1734
1735impl<'a> Drop for StoreGuard<'a> {
1736    fn drop(&mut self) {
1737        #[cfg(test)]
1738        self.check_invariants();
1739    }
1740}
1741
1742impl Executor for RealExecutor {
1743    type Sleep = Sleep;
1744
1745    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1746        tokio::task::spawn(future);
1747    }
1748
1749    fn sleep(&self, duration: Duration) -> Self::Sleep {
1750        tokio::time::sleep(duration)
1751    }
1752}
1753
1754fn broadcast<F>(
1755    sender_id: ConnectionId,
1756    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1757    mut f: F,
1758) where
1759    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1760{
1761    for receiver_id in receiver_ids {
1762        if receiver_id != sender_id {
1763            f(receiver_id).trace_err();
1764        }
1765    }
1766}
1767
1768lazy_static! {
1769    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1770}
1771
/// Protocol version number wrapped as a newtype around `u32`.
///
/// Derives the common value-type traits so the version can be logged,
/// copied and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProtocolVersion(u32);
1773
1774impl Header for ProtocolVersion {
1775    fn name() -> &'static HeaderName {
1776        &ZED_PROTOCOL_VERSION
1777    }
1778
1779    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1780    where
1781        Self: Sized,
1782        I: Iterator<Item = &'i axum::http::HeaderValue>,
1783    {
1784        let version = values
1785            .next()
1786            .ok_or_else(axum::headers::Error::invalid)?
1787            .to_str()
1788            .map_err(|_| axum::headers::Error::invalid())?
1789            .parse()
1790            .map_err(|_| axum::headers::Error::invalid())?;
1791        Ok(Self(version))
1792    }
1793
1794    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1795        values.extend([self.0.to_string().parse().unwrap()]);
1796    }
1797}
1798
1799pub fn routes(server: Arc<Server>) -> Router<Body> {
1800    Router::new()
1801        .route("/rpc", get(handle_websocket_request))
1802        .layer(
1803            ServiceBuilder::new()
1804                .layer(Extension(server.app_state.clone()))
1805                .layer(middleware::from_fn(auth::validate_header)),
1806        )
1807        .route("/metrics", get(handle_metrics))
1808        .layer(Extension(server))
1809}
1810
1811pub async fn handle_websocket_request(
1812    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1813    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1814    Extension(server): Extension<Arc<Server>>,
1815    Extension(user): Extension<User>,
1816    ws: WebSocketUpgrade,
1817) -> axum::response::Response {
1818    if protocol_version != rpc::PROTOCOL_VERSION {
1819        return (
1820            StatusCode::UPGRADE_REQUIRED,
1821            "client must be upgraded".to_string(),
1822        )
1823            .into_response();
1824    }
1825    let socket_address = socket_address.to_string();
1826    ws.on_upgrade(move |socket| {
1827        use util::ResultExt;
1828        let socket = socket
1829            .map_ok(to_tungstenite_message)
1830            .err_into()
1831            .with(|message| async move { Ok(to_axum_message(message)) });
1832        let connection = Connection::new(Box::pin(socket));
1833        async move {
1834            server
1835                .handle_connection(connection, socket_address, user, None, RealExecutor)
1836                .await
1837                .log_err();
1838        }
1839    })
1840}
1841
1842pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1843    let metrics = server.store().await.metrics();
1844    METRIC_CONNECTIONS.set(metrics.connections as _);
1845    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1846
1847    let encoder = prometheus::TextEncoder::new();
1848    let metric_families = prometheus::gather();
1849    match encoder.encode_to_string(&metric_families) {
1850        Ok(string) => (StatusCode::OK, string).into_response(),
1851        Err(error) => (
1852            StatusCode::INTERNAL_SERVER_ERROR,
1853            format!("failed to encode metrics {:?}", error),
1854        )
1855            .into_response(),
1856    }
1857}
1858
1859fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1860    match message {
1861        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1862        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1863        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1864        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1865        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1866            code: frame.code.into(),
1867            reason: frame.reason,
1868        })),
1869    }
1870}
1871
1872fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1873    match message {
1874        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1875        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1876        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1877        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1878        AxumMessage::Close(frame) => {
1879            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1880                code: frame.code.into(),
1881                reason: frame.reason,
1882            }))
1883        }
1884    }
1885}
1886
/// Extension trait for turning a fallible value into an `Option`, with the
/// failure reported by the implementation instead of being propagated.
pub trait ResultExt {
    /// The success value carried by the implementing type.
    type Ok;

    /// Consumes `self`, returning the success value if there is one and
    /// `None` otherwise.
    fn trace_err(self) -> Option<Self::Ok>;
}
1892
1893impl<T, E> ResultExt for Result<T, E>
1894where
1895    E: std::fmt::Debug,
1896{
1897    type Ok = T;
1898
1899    fn trace_err(self) -> Option<T> {
1900        match self {
1901            Ok(value) => Some(value),
1902            Err(error) => {
1903                tracing::error!("{:?}", error);
1904                None
1905            }
1906        }
1907    }
1908}