rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    os::unix::prelude::OsStrExt,
  46    rc::Rc,
  47    sync::{
  48        atomic::{AtomicBool, Ordering::SeqCst},
  49        Arc,
  50    },
  51    time::Duration,
  52};
  53pub use store::{Store, Worktree};
  54use tokio::{
  55    sync::{Mutex, MutexGuard},
  56    time::Sleep,
  57};
  58use tower::ServiceBuilder;
  59use tracing::{info_span, instrument, Instrument};
  60
  61lazy_static! {
  62    static ref METRIC_CONNECTIONS: IntGauge =
  63        register_int_gauge!("connections", "number of connections").unwrap();
  64    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  65        "shared_projects",
  66        "number of open projects with one or more guests"
  67    )
  68    .unwrap();
  69}
  70
  71type MessageHandler = Box<
  72    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
  73>;
  74
  75struct Message<T> {
  76    sender_user_id: UserId,
  77    sender_connection_id: ConnectionId,
  78    payload: T,
  79}
  80
  81struct Response<R> {
  82    server: Arc<Server>,
  83    receipt: Receipt<R>,
  84    responded: Arc<AtomicBool>,
  85}
  86
  87impl<R: RequestMessage> Response<R> {
  88    fn send(self, payload: R::Response) -> Result<()> {
  89        self.responded.store(true, SeqCst);
  90        self.server.peer.respond(self.receipt, payload)?;
  91        Ok(())
  92    }
  93}
  94
  95pub struct Server {
  96    peer: Arc<Peer>,
  97    pub(crate) store: Mutex<Store>,
  98    app_state: Arc<AppState>,
  99    handlers: HashMap<TypeId, MessageHandler>,
 100}
 101
 102pub trait Executor: Send + Clone {
 103    type Sleep: Send + Future;
 104    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 105    fn sleep(&self, duration: Duration) -> Self::Sleep;
 106}
 107
 108#[derive(Clone)]
 109pub struct RealExecutor;
 110
 111pub(crate) struct StoreGuard<'a> {
 112    guard: MutexGuard<'a, Store>,
 113    _not_send: PhantomData<Rc<()>>,
 114}
 115
 116#[derive(Serialize)]
 117pub struct ServerSnapshot<'a> {
 118    peer: &'a Peer,
 119    #[serde(serialize_with = "serialize_deref")]
 120    store: StoreGuard<'a>,
 121}
 122
 123pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 124where
 125    S: Serializer,
 126    T: Deref<Target = U>,
 127    U: Serialize,
 128{
 129    Serialize::serialize(value.deref(), serializer)
 130}
 131
 132impl Server {
 133    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 134        let mut server = Self {
 135            peer: Peer::new(),
 136            app_state,
 137            store: Default::default(),
 138            handlers: Default::default(),
 139        };
 140
 141        server
 142            .add_request_handler(Server::ping)
 143            .add_request_handler(Server::create_room)
 144            .add_request_handler(Server::join_room)
 145            .add_message_handler(Server::leave_room)
 146            .add_request_handler(Server::call)
 147            .add_request_handler(Server::cancel_call)
 148            .add_message_handler(Server::decline_call)
 149            .add_request_handler(Server::update_participant_location)
 150            .add_request_handler(Server::share_project)
 151            .add_message_handler(Server::unshare_project)
 152            .add_request_handler(Server::join_project)
 153            .add_message_handler(Server::leave_project)
 154            .add_request_handler(Server::update_project)
 155            .add_request_handler(Server::update_worktree)
 156            .add_message_handler(Server::start_language_server)
 157            .add_message_handler(Server::update_language_server)
 158            .add_message_handler(Server::update_diagnostic_summary)
 159            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 160            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 161            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
 162            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 163            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 164            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 165            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 166            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 167            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 168            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 169            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 170            .add_request_handler(
 171                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 172            )
 173            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 174            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 175            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 176            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 177            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 178            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 179            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 180            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 181            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 182            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 183            .add_message_handler(Server::create_buffer_for_peer)
 184            .add_request_handler(Server::update_buffer)
 185            .add_message_handler(Server::update_buffer_file)
 186            .add_message_handler(Server::buffer_reloaded)
 187            .add_message_handler(Server::buffer_saved)
 188            .add_request_handler(Server::save_buffer)
 189            .add_request_handler(Server::get_users)
 190            .add_request_handler(Server::fuzzy_search_users)
 191            .add_request_handler(Server::request_contact)
 192            .add_request_handler(Server::remove_contact)
 193            .add_request_handler(Server::respond_to_contact_request)
 194            .add_request_handler(Server::follow)
 195            .add_message_handler(Server::unfollow)
 196            .add_message_handler(Server::update_followers)
 197            .add_message_handler(Server::update_diff_base)
 198            .add_request_handler(Server::get_private_user_info);
 199
 200        Arc::new(server)
 201    }
 202
 203    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 204    where
 205        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 206        Fut: 'static + Send + Future<Output = Result<()>>,
 207        M: EnvelopedMessage,
 208    {
 209        let prev_handler = self.handlers.insert(
 210            TypeId::of::<M>(),
 211            Box::new(move |server, sender_user_id, envelope| {
 212                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 213                let span = info_span!(
 214                    "handle message",
 215                    payload_type = envelope.payload_type_name()
 216                );
 217                span.in_scope(|| {
 218                    tracing::info!(
 219                        payload_type = envelope.payload_type_name(),
 220                        "message received"
 221                    );
 222                });
 223                let future = (handler)(server, sender_user_id, *envelope);
 224                async move {
 225                    if let Err(error) = future.await {
 226                        tracing::error!(%error, "error handling message");
 227                    }
 228                }
 229                .instrument(span)
 230                .boxed()
 231            }),
 232        );
 233        if prev_handler.is_some() {
 234            panic!("registered a handler for the same message twice");
 235        }
 236        self
 237    }
 238
 239    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 240    where
 241        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 242        Fut: 'static + Send + Future<Output = Result<()>>,
 243        M: EnvelopedMessage,
 244    {
 245        self.add_handler(move |server, sender_user_id, envelope| {
 246            handler(
 247                server,
 248                Message {
 249                    sender_user_id,
 250                    sender_connection_id: envelope.sender_id,
 251                    payload: envelope.payload,
 252                },
 253            )
 254        });
 255        self
 256    }
 257
 258    /// Handle a request while holding a lock to the store. This is useful when we're registering
 259    /// a connection but we want to respond on the connection before anybody else can send on it.
 260    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 261    where
 262        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 263        Fut: Send + Future<Output = Result<()>>,
 264        M: RequestMessage,
 265    {
 266        let handler = Arc::new(handler);
 267        self.add_handler(move |server, sender_user_id, envelope| {
 268            let receipt = envelope.receipt();
 269            let handler = handler.clone();
 270            async move {
 271                let request = Message {
 272                    sender_user_id,
 273                    sender_connection_id: envelope.sender_id,
 274                    payload: envelope.payload,
 275                };
 276                let responded = Arc::new(AtomicBool::default());
 277                let response = Response {
 278                    server: server.clone(),
 279                    responded: responded.clone(),
 280                    receipt,
 281                };
 282                match (handler)(server.clone(), request, response).await {
 283                    Ok(()) => {
 284                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 285                            Ok(())
 286                        } else {
 287                            Err(anyhow!("handler did not send a response"))?
 288                        }
 289                    }
 290                    Err(error) => {
 291                        server.peer.respond_with_error(
 292                            receipt,
 293                            proto::Error {
 294                                message: error.to_string(),
 295                            },
 296                        )?;
 297                        Err(error)
 298                    }
 299                }
 300            }
 301        })
 302    }
 303
 304    pub fn handle_connection<E: Executor>(
 305        self: &Arc<Self>,
 306        connection: Connection,
 307        address: String,
 308        user: User,
 309        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 310        executor: E,
 311    ) -> impl Future<Output = Result<()>> {
 312        let mut this = self.clone();
 313        let user_id = user.id;
 314        let login = user.github_login;
 315        let span = info_span!("handle connection", %user_id, %login, %address);
 316        async move {
 317            let (connection_id, handle_io, mut incoming_rx) = this
 318                .peer
 319                .add_connection(connection, {
 320                    let executor = executor.clone();
 321                    move |duration| {
 322                        let timer = executor.sleep(duration);
 323                        async move {
 324                            timer.await;
 325                        }
 326                    }
 327                });
 328
 329            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 330            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 331            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 332
 333            if let Some(send_connection_id) = send_connection_id.take() {
 334                let _ = send_connection_id.send(connection_id);
 335            }
 336
 337            if !user.connected_once {
 338                this.peer.send(connection_id, proto::ShowContacts {})?;
 339                this.app_state.db.set_user_connected_once(user_id, true).await?;
 340            }
 341
 342            let (contacts, invite_code) = future::try_join(
 343                this.app_state.db.get_contacts(user_id),
 344                this.app_state.db.get_invite_code_for_user(user_id)
 345            ).await?;
 346
 347            {
 348                let mut store = this.store().await;
 349                store.add_connection(connection_id, user_id, user.admin);
 350                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 351
 352                if let Some((code, count)) = invite_code {
 353                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 354                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 355                        count,
 356                    })?;
 357                }
 358            }
 359
 360            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 361                this.peer.send(connection_id, incoming_call)?;
 362            }
 363
 364            this.update_user_contacts(user_id).await?;
 365
 366            let handle_io = handle_io.fuse();
 367            futures::pin_mut!(handle_io);
 368
 369            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 370            // This prevents deadlocks when e.g., client A performs a request to client B and
 371            // client B performs a request to client A. If both clients stop processing further
 372            // messages until their respective request completes, they won't have a chance to
 373            // respond to the other client's request and cause a deadlock.
 374            //
 375            // This arrangement ensures we will attempt to process earlier messages first, but fall
 376            // back to processing messages arrived later in the spirit of making progress.
 377            let mut foreground_message_handlers = FuturesUnordered::new();
 378            loop {
 379                let next_message = incoming_rx.next().fuse();
 380                futures::pin_mut!(next_message);
 381                futures::select_biased! {
 382                    result = handle_io => {
 383                        if let Err(error) = result {
 384                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 385                        }
 386                        break;
 387                    }
 388                    _ = foreground_message_handlers.next() => {}
 389                    message = next_message => {
 390                        if let Some(message) = message {
 391                            let type_name = message.payload_type_name();
 392                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 393                            let span_enter = span.enter();
 394                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 395                                let is_background = message.is_background();
 396                                let handle_message = (handler)(this.clone(), user_id, message);
 397                                drop(span_enter);
 398
 399                                let handle_message = handle_message.instrument(span);
 400                                if is_background {
 401                                    executor.spawn_detached(handle_message);
 402                                } else {
 403                                    foreground_message_handlers.push(handle_message);
 404                                }
 405                            } else {
 406                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 407                            }
 408                        } else {
 409                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 410                            break;
 411                        }
 412                    }
 413                }
 414            }
 415
 416            drop(foreground_message_handlers);
 417            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 418            if let Err(error) = this.sign_out(connection_id, user_id).await {
 419                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 420            }
 421
 422            Ok(())
 423        }.instrument(span)
 424    }
 425
 426    #[instrument(skip(self), err)]
 427    async fn sign_out(
 428        self: &mut Arc<Self>,
 429        connection_id: ConnectionId,
 430        user_id: UserId,
 431    ) -> Result<()> {
 432        self.peer.disconnect(connection_id);
 433        let decline_calls = {
 434            let mut store = self.store().await;
 435            store.remove_connection(connection_id)?;
 436            let mut connections = store.connection_ids_for_user(user_id);
 437            connections.next().is_none()
 438        };
 439
 440        self.leave_room_for_connection(connection_id, user_id)
 441            .await
 442            .trace_err();
 443        if decline_calls {
 444            if let Some(room) = self
 445                .app_state
 446                .db
 447                .decline_call(None, user_id)
 448                .await
 449                .trace_err()
 450            {
 451                self.room_updated(&room);
 452            }
 453        }
 454
 455        self.update_user_contacts(user_id).await?;
 456
 457        Ok(())
 458    }
 459
 460    pub async fn invite_code_redeemed(
 461        self: &Arc<Self>,
 462        inviter_id: UserId,
 463        invitee_id: UserId,
 464    ) -> Result<()> {
 465        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 466            if let Some(code) = &user.invite_code {
 467                let store = self.store().await;
 468                let invitee_contact = store.contact_for_user(invitee_id, true);
 469                for connection_id in store.connection_ids_for_user(inviter_id) {
 470                    self.peer.send(
 471                        connection_id,
 472                        proto::UpdateContacts {
 473                            contacts: vec![invitee_contact.clone()],
 474                            ..Default::default()
 475                        },
 476                    )?;
 477                    self.peer.send(
 478                        connection_id,
 479                        proto::UpdateInviteInfo {
 480                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 481                            count: user.invite_count as u32,
 482                        },
 483                    )?;
 484                }
 485            }
 486        }
 487        Ok(())
 488    }
 489
 490    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 491        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 492            if let Some(invite_code) = &user.invite_code {
 493                let store = self.store().await;
 494                for connection_id in store.connection_ids_for_user(user_id) {
 495                    self.peer.send(
 496                        connection_id,
 497                        proto::UpdateInviteInfo {
 498                            url: format!(
 499                                "{}{}",
 500                                self.app_state.config.invite_link_prefix, invite_code
 501                            ),
 502                            count: user.invite_count as u32,
 503                        },
 504                    )?;
 505                }
 506            }
 507        }
 508        Ok(())
 509    }
 510
 511    async fn ping(
 512        self: Arc<Server>,
 513        _: Message<proto::Ping>,
 514        response: Response<proto::Ping>,
 515    ) -> Result<()> {
 516        response.send(proto::Ack {})?;
 517        Ok(())
 518    }
 519
 520    async fn create_room(
 521        self: Arc<Server>,
 522        request: Message<proto::CreateRoom>,
 523        response: Response<proto::CreateRoom>,
 524    ) -> Result<()> {
 525        let room = self
 526            .app_state
 527            .db
 528            .create_room(request.sender_user_id, request.sender_connection_id)
 529            .await?;
 530
 531        let live_kit_connection_info =
 532            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 533                if let Some(_) = live_kit
 534                    .create_room(room.live_kit_room.clone())
 535                    .await
 536                    .trace_err()
 537                {
 538                    if let Some(token) = live_kit
 539                        .room_token(
 540                            &room.live_kit_room,
 541                            &request.sender_connection_id.to_string(),
 542                        )
 543                        .trace_err()
 544                    {
 545                        Some(proto::LiveKitConnectionInfo {
 546                            server_url: live_kit.url().into(),
 547                            token,
 548                        })
 549                    } else {
 550                        None
 551                    }
 552                } else {
 553                    None
 554                }
 555            } else {
 556                None
 557            };
 558
 559        response.send(proto::CreateRoomResponse {
 560            room: Some(room),
 561            live_kit_connection_info,
 562        })?;
 563        self.update_user_contacts(request.sender_user_id).await?;
 564        Ok(())
 565    }
 566
 567    async fn join_room(
 568        self: Arc<Server>,
 569        request: Message<proto::JoinRoom>,
 570        response: Response<proto::JoinRoom>,
 571    ) -> Result<()> {
 572        let room = self
 573            .app_state
 574            .db
 575            .join_room(
 576                RoomId::from_proto(request.payload.id),
 577                request.sender_user_id,
 578                request.sender_connection_id,
 579            )
 580            .await?;
 581        for connection_id in self
 582            .store()
 583            .await
 584            .connection_ids_for_user(request.sender_user_id)
 585        {
 586            self.peer
 587                .send(connection_id, proto::CallCanceled {})
 588                .trace_err();
 589        }
 590
 591        let live_kit_connection_info =
 592            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 593                if let Some(token) = live_kit
 594                    .room_token(
 595                        &room.live_kit_room,
 596                        &request.sender_connection_id.to_string(),
 597                    )
 598                    .trace_err()
 599                {
 600                    Some(proto::LiveKitConnectionInfo {
 601                        server_url: live_kit.url().into(),
 602                        token,
 603                    })
 604                } else {
 605                    None
 606                }
 607            } else {
 608                None
 609            };
 610
 611        self.room_updated(&room);
 612        response.send(proto::JoinRoomResponse {
 613            room: Some(room),
 614            live_kit_connection_info,
 615        })?;
 616
 617        self.update_user_contacts(request.sender_user_id).await?;
 618        Ok(())
 619    }
 620
 621    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 622        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
 623            .await
 624    }
 625
 626    async fn leave_room_for_connection(
 627        self: &Arc<Server>,
 628        connection_id: ConnectionId,
 629        user_id: UserId,
 630    ) -> Result<()> {
 631        let mut contacts_to_update = HashSet::default();
 632
 633        let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else {
 634            return Err(anyhow!("no room to leave"))?;
 635        };
 636        contacts_to_update.insert(user_id);
 637
 638        for project in left_room.left_projects.into_values() {
 639            if project.host_user_id == user_id {
 640                for connection_id in project.connection_ids {
 641                    self.peer
 642                        .send(
 643                            connection_id,
 644                            proto::UnshareProject {
 645                                project_id: project.id.to_proto(),
 646                            },
 647                        )
 648                        .trace_err();
 649                }
 650            } else {
 651                for connection_id in project.connection_ids {
 652                    self.peer
 653                        .send(
 654                            connection_id,
 655                            proto::RemoveProjectCollaborator {
 656                                project_id: project.id.to_proto(),
 657                                peer_id: connection_id.0,
 658                            },
 659                        )
 660                        .trace_err();
 661                }
 662
 663                self.peer
 664                    .send(
 665                        connection_id,
 666                        proto::UnshareProject {
 667                            project_id: project.id.to_proto(),
 668                        },
 669                    )
 670                    .trace_err();
 671            }
 672        }
 673
 674        self.room_updated(&left_room.room);
 675        {
 676            let store = self.store().await;
 677            for canceled_user_id in left_room.canceled_calls_to_user_ids {
 678                for connection_id in store.connection_ids_for_user(canceled_user_id) {
 679                    self.peer
 680                        .send(connection_id, proto::CallCanceled {})
 681                        .trace_err();
 682                }
 683                contacts_to_update.insert(canceled_user_id);
 684            }
 685        }
 686
 687        for contact_user_id in contacts_to_update {
 688            self.update_user_contacts(contact_user_id).await?;
 689        }
 690
 691        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 692            live_kit
 693                .remove_participant(
 694                    left_room.room.live_kit_room.clone(),
 695                    connection_id.to_string(),
 696                )
 697                .await
 698                .trace_err();
 699
 700            if left_room.room.participants.is_empty() {
 701                live_kit
 702                    .delete_room(left_room.room.live_kit_room)
 703                    .await
 704                    .trace_err();
 705            }
 706        }
 707
 708        Ok(())
 709    }
 710
 711    async fn call(
 712        self: Arc<Server>,
 713        request: Message<proto::Call>,
 714        response: Response<proto::Call>,
 715    ) -> Result<()> {
 716        let room_id = RoomId::from_proto(request.payload.room_id);
 717        let calling_user_id = request.sender_user_id;
 718        let calling_connection_id = request.sender_connection_id;
 719        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 720        let initial_project_id = request
 721            .payload
 722            .initial_project_id
 723            .map(ProjectId::from_proto);
 724        if !self
 725            .app_state
 726            .db
 727            .has_contact(calling_user_id, called_user_id)
 728            .await?
 729        {
 730            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 731        }
 732
 733        let (room, incoming_call) = self
 734            .app_state
 735            .db
 736            .call(
 737                room_id,
 738                calling_user_id,
 739                calling_connection_id,
 740                called_user_id,
 741                initial_project_id,
 742            )
 743            .await?;
 744        self.room_updated(&room);
 745        self.update_user_contacts(called_user_id).await?;
 746
 747        let mut calls = self
 748            .store()
 749            .await
 750            .connection_ids_for_user(called_user_id)
 751            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 752            .collect::<FuturesUnordered<_>>();
 753
 754        while let Some(call_response) = calls.next().await {
 755            match call_response.as_ref() {
 756                Ok(_) => {
 757                    response.send(proto::Ack {})?;
 758                    return Ok(());
 759                }
 760                Err(_) => {
 761                    call_response.trace_err();
 762                }
 763            }
 764        }
 765
 766        let room = self
 767            .app_state
 768            .db
 769            .call_failed(room_id, called_user_id)
 770            .await?;
 771        self.room_updated(&room);
 772        self.update_user_contacts(called_user_id).await?;
 773
 774        Err(anyhow!("failed to ring user"))?
 775    }
 776
 777    async fn cancel_call(
 778        self: Arc<Server>,
 779        request: Message<proto::CancelCall>,
 780        response: Response<proto::CancelCall>,
 781    ) -> Result<()> {
 782        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 783        let room_id = RoomId::from_proto(request.payload.room_id);
 784        let room = self
 785            .app_state
 786            .db
 787            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
 788            .await?;
 789        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
 790            self.peer
 791                .send(connection_id, proto::CallCanceled {})
 792                .trace_err();
 793        }
 794        self.room_updated(&room);
 795        response.send(proto::Ack {})?;
 796
 797        self.update_user_contacts(called_user_id).await?;
 798        Ok(())
 799    }
 800
 801    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 802        let room_id = RoomId::from_proto(message.payload.room_id);
 803        let room = self
 804            .app_state
 805            .db
 806            .decline_call(Some(room_id), message.sender_user_id)
 807            .await?;
 808        for connection_id in self
 809            .store()
 810            .await
 811            .connection_ids_for_user(message.sender_user_id)
 812        {
 813            self.peer
 814                .send(connection_id, proto::CallCanceled {})
 815                .trace_err();
 816        }
 817        self.room_updated(&room);
 818        self.update_user_contacts(message.sender_user_id).await?;
 819        Ok(())
 820    }
 821
 822    async fn update_participant_location(
 823        self: Arc<Server>,
 824        request: Message<proto::UpdateParticipantLocation>,
 825        response: Response<proto::UpdateParticipantLocation>,
 826    ) -> Result<()> {
 827        let room_id = RoomId::from_proto(request.payload.room_id);
 828        let location = request
 829            .payload
 830            .location
 831            .ok_or_else(|| anyhow!("invalid location"))?;
 832        let room = self
 833            .app_state
 834            .db
 835            .update_room_participant_location(room_id, request.sender_connection_id, location)
 836            .await?;
 837        self.room_updated(&room);
 838        response.send(proto::Ack {})?;
 839        Ok(())
 840    }
 841
 842    fn room_updated(&self, room: &proto::Room) {
 843        for participant in &room.participants {
 844            self.peer
 845                .send(
 846                    ConnectionId(participant.peer_id),
 847                    proto::RoomUpdated {
 848                        room: Some(room.clone()),
 849                    },
 850                )
 851                .trace_err();
 852        }
 853    }
 854
 855    async fn share_project(
 856        self: Arc<Server>,
 857        request: Message<proto::ShareProject>,
 858        response: Response<proto::ShareProject>,
 859    ) -> Result<()> {
 860        let (project_id, room) = self
 861            .app_state
 862            .db
 863            .share_project(
 864                RoomId::from_proto(request.payload.room_id),
 865                request.sender_user_id,
 866                request.sender_connection_id,
 867                &request.payload.worktrees,
 868            )
 869            .await
 870            .unwrap();
 871        response.send(proto::ShareProjectResponse {
 872            project_id: project_id.to_proto(),
 873        })?;
 874        self.room_updated(&room);
 875
 876        Ok(())
 877    }
 878
 879    async fn unshare_project(
 880        self: Arc<Server>,
 881        message: Message<proto::UnshareProject>,
 882    ) -> Result<()> {
 883        let project_id = ProjectId::from_proto(message.payload.project_id);
 884        let mut store = self.store().await;
 885        let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?;
 886        broadcast(
 887            message.sender_connection_id,
 888            project.guest_connection_ids(),
 889            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 890        );
 891        self.room_updated(room);
 892
 893        Ok(())
 894    }
 895
 896    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 897        let contacts = self.app_state.db.get_contacts(user_id).await?;
 898        let store = self.store().await;
 899        let updated_contact = store.contact_for_user(user_id, false);
 900        for contact in contacts {
 901            if let db::Contact::Accepted {
 902                user_id: contact_user_id,
 903                ..
 904            } = contact
 905            {
 906                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 907                    self.peer
 908                        .send(
 909                            contact_conn_id,
 910                            proto::UpdateContacts {
 911                                contacts: vec![updated_contact.clone()],
 912                                remove_contacts: Default::default(),
 913                                incoming_requests: Default::default(),
 914                                remove_incoming_requests: Default::default(),
 915                                outgoing_requests: Default::default(),
 916                                remove_outgoing_requests: Default::default(),
 917                            },
 918                        )
 919                        .trace_err();
 920                }
 921            }
 922        }
 923        Ok(())
 924    }
 925
 926    async fn join_project(
 927        self: Arc<Server>,
 928        request: Message<proto::JoinProject>,
 929        response: Response<proto::JoinProject>,
 930    ) -> Result<()> {
 931        let project_id = ProjectId::from_proto(request.payload.project_id);
 932        let guest_user_id = request.sender_user_id;
 933        let host_user_id;
 934        let host_connection_id;
 935        {
 936            let state = self.store().await;
 937            let project = state.project(project_id)?;
 938            host_user_id = project.host.user_id;
 939            host_connection_id = project.host_connection_id;
 940        };
 941
 942        tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 943
 944        let mut store = self.store().await;
 945        let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?;
 946        let peer_count = project.guests.len();
 947        let mut collaborators = Vec::with_capacity(peer_count);
 948        collaborators.push(proto::Collaborator {
 949            peer_id: project.host_connection_id.0,
 950            replica_id: 0,
 951            user_id: project.host.user_id.to_proto(),
 952        });
 953        let worktrees = project
 954            .worktrees
 955            .iter()
 956            .map(|(id, worktree)| proto::WorktreeMetadata {
 957                id: *id,
 958                root_name: worktree.root_name.clone(),
 959                visible: worktree.visible,
 960                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
 961            })
 962            .collect::<Vec<_>>();
 963
 964        // Add all guests other than the requesting user's own connections as collaborators
 965        for (guest_conn_id, guest) in &project.guests {
 966            if request.sender_connection_id != *guest_conn_id {
 967                collaborators.push(proto::Collaborator {
 968                    peer_id: guest_conn_id.0,
 969                    replica_id: guest.replica_id as u32,
 970                    user_id: guest.user_id.to_proto(),
 971                });
 972            }
 973        }
 974
 975        for conn_id in project.connection_ids() {
 976            if conn_id != request.sender_connection_id {
 977                self.peer
 978                    .send(
 979                        conn_id,
 980                        proto::AddProjectCollaborator {
 981                            project_id: project_id.to_proto(),
 982                            collaborator: Some(proto::Collaborator {
 983                                peer_id: request.sender_connection_id.0,
 984                                replica_id: replica_id as u32,
 985                                user_id: guest_user_id.to_proto(),
 986                            }),
 987                        },
 988                    )
 989                    .trace_err();
 990            }
 991        }
 992
 993        // First, we send the metadata associated with each worktree.
 994        response.send(proto::JoinProjectResponse {
 995            worktrees: worktrees.clone(),
 996            replica_id: replica_id as u32,
 997            collaborators: collaborators.clone(),
 998            language_servers: project.language_servers.clone(),
 999        })?;
1000
1001        for (worktree_id, worktree) in &project.worktrees {
1002            #[cfg(any(test, feature = "test-support"))]
1003            const MAX_CHUNK_SIZE: usize = 2;
1004            #[cfg(not(any(test, feature = "test-support")))]
1005            const MAX_CHUNK_SIZE: usize = 256;
1006
1007            // Stream this worktree's entries.
1008            let message = proto::UpdateWorktree {
1009                project_id: project_id.to_proto(),
1010                worktree_id: *worktree_id,
1011                abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
1012                root_name: worktree.root_name.clone(),
1013                updated_entries: worktree.entries.values().cloned().collect(),
1014                removed_entries: Default::default(),
1015                scan_id: worktree.scan_id,
1016                is_last_update: worktree.is_complete,
1017            };
1018            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1019                self.peer
1020                    .send(request.sender_connection_id, update.clone())?;
1021            }
1022
1023            // Stream this worktree's diagnostics.
1024            for summary in worktree.diagnostic_summaries.values() {
1025                self.peer.send(
1026                    request.sender_connection_id,
1027                    proto::UpdateDiagnosticSummary {
1028                        project_id: project_id.to_proto(),
1029                        worktree_id: *worktree_id,
1030                        summary: Some(summary.clone()),
1031                    },
1032                )?;
1033            }
1034        }
1035
1036        for language_server in &project.language_servers {
1037            self.peer.send(
1038                request.sender_connection_id,
1039                proto::UpdateLanguageServer {
1040                    project_id: project_id.to_proto(),
1041                    language_server_id: language_server.id,
1042                    variant: Some(
1043                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1044                            proto::LspDiskBasedDiagnosticsUpdated {},
1045                        ),
1046                    ),
1047                },
1048            )?;
1049        }
1050
1051        Ok(())
1052    }
1053
1054    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1055        let sender_id = request.sender_connection_id;
1056        let project_id = ProjectId::from_proto(request.payload.project_id);
1057        let project;
1058        {
1059            let mut store = self.store().await;
1060            project = store.leave_project(project_id, sender_id)?;
1061            tracing::info!(
1062                %project_id,
1063                host_user_id = %project.host_user_id,
1064                host_connection_id = %project.host_connection_id,
1065                "leave project"
1066            );
1067
1068            if project.remove_collaborator {
1069                broadcast(sender_id, project.connection_ids, |conn_id| {
1070                    self.peer.send(
1071                        conn_id,
1072                        proto::RemoveProjectCollaborator {
1073                            project_id: project_id.to_proto(),
1074                            peer_id: sender_id.0,
1075                        },
1076                    )
1077                });
1078            }
1079        }
1080
1081        Ok(())
1082    }
1083
1084    async fn update_project(
1085        self: Arc<Server>,
1086        request: Message<proto::UpdateProject>,
1087        response: Response<proto::UpdateProject>,
1088    ) -> Result<()> {
1089        let project_id = ProjectId::from_proto(request.payload.project_id);
1090        let (room, guest_connection_ids) = self
1091            .app_state
1092            .db
1093            .update_project(
1094                project_id,
1095                request.sender_connection_id,
1096                &request.payload.worktrees,
1097            )
1098            .await?;
1099        broadcast(
1100            request.sender_connection_id,
1101            guest_connection_ids,
1102            |connection_id| {
1103                self.peer.send(
1104                    connection_id,
1105                    proto::ProjectUpdated {
1106                        project_id: project_id.to_proto(),
1107                        worktrees: request.payload.worktrees.clone(),
1108                        room_version: room.version,
1109                    },
1110                )
1111            },
1112        );
1113        self.room_updated(&room);
1114        response.send(proto::Ack {})?;
1115
1116        Ok(())
1117    }
1118
1119    async fn update_worktree(
1120        self: Arc<Server>,
1121        request: Message<proto::UpdateWorktree>,
1122        response: Response<proto::UpdateWorktree>,
1123    ) -> Result<()> {
1124        let project_id = ProjectId::from_proto(request.payload.project_id);
1125        let worktree_id = request.payload.worktree_id;
1126        let connection_ids = self.store().await.update_worktree(
1127            request.sender_connection_id,
1128            project_id,
1129            worktree_id,
1130            &request.payload.root_name,
1131            &request.payload.removed_entries,
1132            &request.payload.updated_entries,
1133            request.payload.scan_id,
1134            request.payload.is_last_update,
1135        )?;
1136
1137        broadcast(
1138            request.sender_connection_id,
1139            connection_ids,
1140            |connection_id| {
1141                self.peer.forward_send(
1142                    request.sender_connection_id,
1143                    connection_id,
1144                    request.payload.clone(),
1145                )
1146            },
1147        );
1148        response.send(proto::Ack {})?;
1149        Ok(())
1150    }
1151
1152    async fn update_diagnostic_summary(
1153        self: Arc<Server>,
1154        request: Message<proto::UpdateDiagnosticSummary>,
1155    ) -> Result<()> {
1156        let summary = request
1157            .payload
1158            .summary
1159            .clone()
1160            .ok_or_else(|| anyhow!("invalid summary"))?;
1161        let receiver_ids = self.store().await.update_diagnostic_summary(
1162            ProjectId::from_proto(request.payload.project_id),
1163            request.payload.worktree_id,
1164            request.sender_connection_id,
1165            summary,
1166        )?;
1167
1168        broadcast(
1169            request.sender_connection_id,
1170            receiver_ids,
1171            |connection_id| {
1172                self.peer.forward_send(
1173                    request.sender_connection_id,
1174                    connection_id,
1175                    request.payload.clone(),
1176                )
1177            },
1178        );
1179        Ok(())
1180    }
1181
1182    async fn start_language_server(
1183        self: Arc<Server>,
1184        request: Message<proto::StartLanguageServer>,
1185    ) -> Result<()> {
1186        let receiver_ids = self.store().await.start_language_server(
1187            ProjectId::from_proto(request.payload.project_id),
1188            request.sender_connection_id,
1189            request
1190                .payload
1191                .server
1192                .clone()
1193                .ok_or_else(|| anyhow!("invalid language server"))?,
1194        )?;
1195        broadcast(
1196            request.sender_connection_id,
1197            receiver_ids,
1198            |connection_id| {
1199                self.peer.forward_send(
1200                    request.sender_connection_id,
1201                    connection_id,
1202                    request.payload.clone(),
1203                )
1204            },
1205        );
1206        Ok(())
1207    }
1208
1209    async fn update_language_server(
1210        self: Arc<Server>,
1211        request: Message<proto::UpdateLanguageServer>,
1212    ) -> Result<()> {
1213        let receiver_ids = self.store().await.project_connection_ids(
1214            ProjectId::from_proto(request.payload.project_id),
1215            request.sender_connection_id,
1216        )?;
1217        broadcast(
1218            request.sender_connection_id,
1219            receiver_ids,
1220            |connection_id| {
1221                self.peer.forward_send(
1222                    request.sender_connection_id,
1223                    connection_id,
1224                    request.payload.clone(),
1225                )
1226            },
1227        );
1228        Ok(())
1229    }
1230
1231    async fn forward_project_request<T>(
1232        self: Arc<Server>,
1233        request: Message<T>,
1234        response: Response<T>,
1235    ) -> Result<()>
1236    where
1237        T: EntityMessage + RequestMessage,
1238    {
1239        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1240        let host_connection_id = self
1241            .store()
1242            .await
1243            .read_project(project_id, request.sender_connection_id)?
1244            .host_connection_id;
1245        let payload = self
1246            .peer
1247            .forward_request(
1248                request.sender_connection_id,
1249                host_connection_id,
1250                request.payload,
1251            )
1252            .await?;
1253
1254        // Ensure project still exists by the time we get the response from the host.
1255        self.store()
1256            .await
1257            .read_project(project_id, request.sender_connection_id)?;
1258
1259        response.send(payload)?;
1260        Ok(())
1261    }
1262
1263    async fn save_buffer(
1264        self: Arc<Server>,
1265        request: Message<proto::SaveBuffer>,
1266        response: Response<proto::SaveBuffer>,
1267    ) -> Result<()> {
1268        let project_id = ProjectId::from_proto(request.payload.project_id);
1269        let host = self
1270            .store()
1271            .await
1272            .read_project(project_id, request.sender_connection_id)?
1273            .host_connection_id;
1274        let response_payload = self
1275            .peer
1276            .forward_request(request.sender_connection_id, host, request.payload.clone())
1277            .await?;
1278
1279        let mut guests = self
1280            .store()
1281            .await
1282            .read_project(project_id, request.sender_connection_id)?
1283            .connection_ids();
1284        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
1285        broadcast(host, guests, |conn_id| {
1286            self.peer
1287                .forward_send(host, conn_id, response_payload.clone())
1288        });
1289        response.send(response_payload)?;
1290        Ok(())
1291    }
1292
1293    async fn create_buffer_for_peer(
1294        self: Arc<Server>,
1295        request: Message<proto::CreateBufferForPeer>,
1296    ) -> Result<()> {
1297        self.peer.forward_send(
1298            request.sender_connection_id,
1299            ConnectionId(request.payload.peer_id),
1300            request.payload,
1301        )?;
1302        Ok(())
1303    }
1304
1305    async fn update_buffer(
1306        self: Arc<Server>,
1307        request: Message<proto::UpdateBuffer>,
1308        response: Response<proto::UpdateBuffer>,
1309    ) -> Result<()> {
1310        let project_id = ProjectId::from_proto(request.payload.project_id);
1311        let receiver_ids = {
1312            let store = self.store().await;
1313            store.project_connection_ids(project_id, request.sender_connection_id)?
1314        };
1315
1316        broadcast(
1317            request.sender_connection_id,
1318            receiver_ids,
1319            |connection_id| {
1320                self.peer.forward_send(
1321                    request.sender_connection_id,
1322                    connection_id,
1323                    request.payload.clone(),
1324                )
1325            },
1326        );
1327        response.send(proto::Ack {})?;
1328        Ok(())
1329    }
1330
1331    async fn update_buffer_file(
1332        self: Arc<Server>,
1333        request: Message<proto::UpdateBufferFile>,
1334    ) -> Result<()> {
1335        let receiver_ids = self.store().await.project_connection_ids(
1336            ProjectId::from_proto(request.payload.project_id),
1337            request.sender_connection_id,
1338        )?;
1339        broadcast(
1340            request.sender_connection_id,
1341            receiver_ids,
1342            |connection_id| {
1343                self.peer.forward_send(
1344                    request.sender_connection_id,
1345                    connection_id,
1346                    request.payload.clone(),
1347                )
1348            },
1349        );
1350        Ok(())
1351    }
1352
1353    async fn buffer_reloaded(
1354        self: Arc<Server>,
1355        request: Message<proto::BufferReloaded>,
1356    ) -> Result<()> {
1357        let receiver_ids = self.store().await.project_connection_ids(
1358            ProjectId::from_proto(request.payload.project_id),
1359            request.sender_connection_id,
1360        )?;
1361        broadcast(
1362            request.sender_connection_id,
1363            receiver_ids,
1364            |connection_id| {
1365                self.peer.forward_send(
1366                    request.sender_connection_id,
1367                    connection_id,
1368                    request.payload.clone(),
1369                )
1370            },
1371        );
1372        Ok(())
1373    }
1374
1375    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1376        let receiver_ids = self.store().await.project_connection_ids(
1377            ProjectId::from_proto(request.payload.project_id),
1378            request.sender_connection_id,
1379        )?;
1380        broadcast(
1381            request.sender_connection_id,
1382            receiver_ids,
1383            |connection_id| {
1384                self.peer.forward_send(
1385                    request.sender_connection_id,
1386                    connection_id,
1387                    request.payload.clone(),
1388                )
1389            },
1390        );
1391        Ok(())
1392    }
1393
1394    async fn follow(
1395        self: Arc<Self>,
1396        request: Message<proto::Follow>,
1397        response: Response<proto::Follow>,
1398    ) -> Result<()> {
1399        let project_id = ProjectId::from_proto(request.payload.project_id);
1400        let leader_id = ConnectionId(request.payload.leader_id);
1401        let follower_id = request.sender_connection_id;
1402        {
1403            let store = self.store().await;
1404            if !store
1405                .project_connection_ids(project_id, follower_id)?
1406                .contains(&leader_id)
1407            {
1408                Err(anyhow!("no such peer"))?;
1409            }
1410        }
1411
1412        let mut response_payload = self
1413            .peer
1414            .forward_request(request.sender_connection_id, leader_id, request.payload)
1415            .await?;
1416        response_payload
1417            .views
1418            .retain(|view| view.leader_id != Some(follower_id.0));
1419        response.send(response_payload)?;
1420        Ok(())
1421    }
1422
1423    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1424        let project_id = ProjectId::from_proto(request.payload.project_id);
1425        let leader_id = ConnectionId(request.payload.leader_id);
1426        let store = self.store().await;
1427        if !store
1428            .project_connection_ids(project_id, request.sender_connection_id)?
1429            .contains(&leader_id)
1430        {
1431            Err(anyhow!("no such peer"))?;
1432        }
1433        self.peer
1434            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1435        Ok(())
1436    }
1437
1438    async fn update_followers(
1439        self: Arc<Self>,
1440        request: Message<proto::UpdateFollowers>,
1441    ) -> Result<()> {
1442        let project_id = ProjectId::from_proto(request.payload.project_id);
1443        let store = self.store().await;
1444        let connection_ids =
1445            store.project_connection_ids(project_id, request.sender_connection_id)?;
1446        let leader_id = request
1447            .payload
1448            .variant
1449            .as_ref()
1450            .and_then(|variant| match variant {
1451                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1452                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1453                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1454            });
1455        for follower_id in &request.payload.follower_ids {
1456            let follower_id = ConnectionId(*follower_id);
1457            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1458                self.peer.forward_send(
1459                    request.sender_connection_id,
1460                    follower_id,
1461                    request.payload.clone(),
1462                )?;
1463            }
1464        }
1465        Ok(())
1466    }
1467
1468    async fn get_users(
1469        self: Arc<Server>,
1470        request: Message<proto::GetUsers>,
1471        response: Response<proto::GetUsers>,
1472    ) -> Result<()> {
1473        let user_ids = request
1474            .payload
1475            .user_ids
1476            .into_iter()
1477            .map(UserId::from_proto)
1478            .collect();
1479        let users = self
1480            .app_state
1481            .db
1482            .get_users_by_ids(user_ids)
1483            .await?
1484            .into_iter()
1485            .map(|user| proto::User {
1486                id: user.id.to_proto(),
1487                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1488                github_login: user.github_login,
1489            })
1490            .collect();
1491        response.send(proto::UsersResponse { users })?;
1492        Ok(())
1493    }
1494
1495    async fn fuzzy_search_users(
1496        self: Arc<Server>,
1497        request: Message<proto::FuzzySearchUsers>,
1498        response: Response<proto::FuzzySearchUsers>,
1499    ) -> Result<()> {
1500        let query = request.payload.query;
1501        let db = &self.app_state.db;
1502        let users = match query.len() {
1503            0 => vec![],
1504            1 | 2 => db
1505                .get_user_by_github_account(&query, None)
1506                .await?
1507                .into_iter()
1508                .collect(),
1509            _ => db.fuzzy_search_users(&query, 10).await?,
1510        };
1511        let users = users
1512            .into_iter()
1513            .filter(|user| user.id != request.sender_user_id)
1514            .map(|user| proto::User {
1515                id: user.id.to_proto(),
1516                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1517                github_login: user.github_login,
1518            })
1519            .collect();
1520        response.send(proto::UsersResponse { users })?;
1521        Ok(())
1522    }
1523
1524    async fn request_contact(
1525        self: Arc<Server>,
1526        request: Message<proto::RequestContact>,
1527        response: Response<proto::RequestContact>,
1528    ) -> Result<()> {
1529        let requester_id = request.sender_user_id;
1530        let responder_id = UserId::from_proto(request.payload.responder_id);
1531        if requester_id == responder_id {
1532            return Err(anyhow!("cannot add yourself as a contact"))?;
1533        }
1534
1535        self.app_state
1536            .db
1537            .send_contact_request(requester_id, responder_id)
1538            .await?;
1539
1540        // Update outgoing contact requests of requester
1541        let mut update = proto::UpdateContacts::default();
1542        update.outgoing_requests.push(responder_id.to_proto());
1543        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1544            self.peer.send(connection_id, update.clone())?;
1545        }
1546
1547        // Update incoming contact requests of responder
1548        let mut update = proto::UpdateContacts::default();
1549        update
1550            .incoming_requests
1551            .push(proto::IncomingContactRequest {
1552                requester_id: requester_id.to_proto(),
1553                should_notify: true,
1554            });
1555        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1556            self.peer.send(connection_id, update.clone())?;
1557        }
1558
1559        response.send(proto::Ack {})?;
1560        Ok(())
1561    }
1562
1563    async fn respond_to_contact_request(
1564        self: Arc<Server>,
1565        request: Message<proto::RespondToContactRequest>,
1566        response: Response<proto::RespondToContactRequest>,
1567    ) -> Result<()> {
1568        let responder_id = request.sender_user_id;
1569        let requester_id = UserId::from_proto(request.payload.requester_id);
1570        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1571            self.app_state
1572                .db
1573                .dismiss_contact_notification(responder_id, requester_id)
1574                .await?;
1575        } else {
1576            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1577            self.app_state
1578                .db
1579                .respond_to_contact_request(responder_id, requester_id, accept)
1580                .await?;
1581
1582            let store = self.store().await;
1583            // Update responder with new contact
1584            let mut update = proto::UpdateContacts::default();
1585            if accept {
1586                update
1587                    .contacts
1588                    .push(store.contact_for_user(requester_id, false));
1589            }
1590            update
1591                .remove_incoming_requests
1592                .push(requester_id.to_proto());
1593            for connection_id in store.connection_ids_for_user(responder_id) {
1594                self.peer.send(connection_id, update.clone())?;
1595            }
1596
1597            // Update requester with new contact
1598            let mut update = proto::UpdateContacts::default();
1599            if accept {
1600                update
1601                    .contacts
1602                    .push(store.contact_for_user(responder_id, true));
1603            }
1604            update
1605                .remove_outgoing_requests
1606                .push(responder_id.to_proto());
1607            for connection_id in store.connection_ids_for_user(requester_id) {
1608                self.peer.send(connection_id, update.clone())?;
1609            }
1610        }
1611
1612        response.send(proto::Ack {})?;
1613        Ok(())
1614    }
1615
1616    async fn remove_contact(
1617        self: Arc<Server>,
1618        request: Message<proto::RemoveContact>,
1619        response: Response<proto::RemoveContact>,
1620    ) -> Result<()> {
1621        let requester_id = request.sender_user_id;
1622        let responder_id = UserId::from_proto(request.payload.user_id);
1623        self.app_state
1624            .db
1625            .remove_contact(requester_id, responder_id)
1626            .await?;
1627
1628        // Update outgoing contact requests of requester
1629        let mut update = proto::UpdateContacts::default();
1630        update
1631            .remove_outgoing_requests
1632            .push(responder_id.to_proto());
1633        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1634            self.peer.send(connection_id, update.clone())?;
1635        }
1636
1637        // Update incoming contact requests of responder
1638        let mut update = proto::UpdateContacts::default();
1639        update
1640            .remove_incoming_requests
1641            .push(requester_id.to_proto());
1642        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1643            self.peer.send(connection_id, update.clone())?;
1644        }
1645
1646        response.send(proto::Ack {})?;
1647        Ok(())
1648    }
1649
1650    async fn update_diff_base(
1651        self: Arc<Server>,
1652        request: Message<proto::UpdateDiffBase>,
1653    ) -> Result<()> {
1654        let receiver_ids = self.store().await.project_connection_ids(
1655            ProjectId::from_proto(request.payload.project_id),
1656            request.sender_connection_id,
1657        )?;
1658        broadcast(
1659            request.sender_connection_id,
1660            receiver_ids,
1661            |connection_id| {
1662                self.peer.forward_send(
1663                    request.sender_connection_id,
1664                    connection_id,
1665                    request.payload.clone(),
1666                )
1667            },
1668        );
1669        Ok(())
1670    }
1671
1672    async fn get_private_user_info(
1673        self: Arc<Self>,
1674        request: Message<proto::GetPrivateUserInfo>,
1675        response: Response<proto::GetPrivateUserInfo>,
1676    ) -> Result<()> {
1677        let metrics_id = self
1678            .app_state
1679            .db
1680            .get_user_metrics_id(request.sender_user_id)
1681            .await?;
1682        let user = self
1683            .app_state
1684            .db
1685            .get_user_by_id(request.sender_user_id)
1686            .await?
1687            .ok_or_else(|| anyhow!("user not found"))?;
1688        response.send(proto::GetPrivateUserInfoResponse {
1689            metrics_id,
1690            staff: user.admin,
1691        })?;
1692        Ok(())
1693    }
1694
    /// Locks the server's `Store` and returns it wrapped in a [`StoreGuard`].
    ///
    /// In test builds this yields to the tokio scheduler both before and after
    /// taking the lock, presumably to provoke different task interleavings.
    pub(crate) async fn store(&self) -> StoreGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.lock().await;
        // NOTE(review): this second yield runs while only the tokio guard is held,
        // before it is wrapped in `StoreGuard`. The `_not_send` marker suggests
        // `StoreGuard` is deliberately `!Send`, so keep this order to avoid holding
        // a `StoreGuard` across an `.await` here. TODO confirm the marker's type.
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1706
1707    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1708        ServerSnapshot {
1709            store: self.store().await,
1710            peer: &self.peer,
1711        }
1712    }
1713}
1714
1715impl<'a> Deref for StoreGuard<'a> {
1716    type Target = Store;
1717
1718    fn deref(&self) -> &Self::Target {
1719        &*self.guard
1720    }
1721}
1722
1723impl<'a> DerefMut for StoreGuard<'a> {
1724    fn deref_mut(&mut self) -> &mut Self::Target {
1725        &mut *self.guard
1726    }
1727}
1728
// Dropping the guard releases the store lock. In test builds the store's
// invariants are checked first, while the lock is still held: the `guard`
// field is only dropped after `drop` returns.
impl<'a> Drop for StoreGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` is presumably a `Store` method reached through
        // `Deref` (TODO confirm). The call is only compiled in test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1735
1736impl Executor for RealExecutor {
1737    type Sleep = Sleep;
1738
1739    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1740        tokio::task::spawn(future);
1741    }
1742
1743    fn sleep(&self, duration: Duration) -> Self::Sleep {
1744        tokio::time::sleep(duration)
1745    }
1746}
1747
1748fn broadcast<F>(
1749    sender_id: ConnectionId,
1750    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1751    mut f: F,
1752) where
1753    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1754{
1755    for receiver_id in receiver_ids {
1756        if receiver_id != sender_id {
1757            f(receiver_id).trace_err();
1758        }
1759    }
1760}
1761
// Name of the request header carrying the client's RPC protocol version,
// decoded by `ProtocolVersion`'s `Header` impl.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1765
/// Typed `x-zed-protocol-version` header: the RPC protocol version a client
/// reports when connecting, as a `u32`.
pub struct ProtocolVersion(u32);
1767
1768impl Header for ProtocolVersion {
1769    fn name() -> &'static HeaderName {
1770        &ZED_PROTOCOL_VERSION
1771    }
1772
1773    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1774    where
1775        Self: Sized,
1776        I: Iterator<Item = &'i axum::http::HeaderValue>,
1777    {
1778        let version = values
1779            .next()
1780            .ok_or_else(axum::headers::Error::invalid)?
1781            .to_str()
1782            .map_err(|_| axum::headers::Error::invalid())?
1783            .parse()
1784            .map_err(|_| axum::headers::Error::invalid())?;
1785        Ok(Self(version))
1786    }
1787
1788    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1789        values.extend([self.0.to_string().parse().unwrap()]);
1790    }
1791}
1792
/// Builds the collab server's HTTP router, exposing `/rpc` and `/metrics`.
///
/// Layer order matters here. In axum, `Router::layer` wraps only the routes
/// already added when it is called. As a result:
/// - the `AppState` extension and `auth::validate_header` apply to `/rpc` only;
/// - `/metrics` is added afterwards, so it is not behind that auth middleware
///   (NOTE(review): presumably intentional, e.g. for scraping; confirm);
/// - the final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // The `AppState` extension is added before `validate_header`,
            // presumably so the middleware can read it. Verify in `auth`.
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1804
/// Upgrades an `/rpc` request to a WebSocket and hands the connection to
/// `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required` ("client must be upgraded") when the
/// client's `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION`. The `User` extension is presumably inserted by
/// `auth::validate_header` (see `routes`); verify.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Brings `util::ResultExt::log_err` into scope for this closure. This is
        // distinct from this module's own `ResultExt`/`trace_err`.
        use util::ResultExt;
        // Adapt axum's socket to tungstenite message types, which `Connection`
        // works with. Incoming messages are mapped axum -> tungstenite, with
        // errors converted via `err_into`. Outgoing messages go through `.with`,
        // which maps tungstenite -> axum before they reach the socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // Any error that ends the connection is logged rather than returned.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1835
1836pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1837    let metrics = server.store().await.metrics();
1838    METRIC_CONNECTIONS.set(metrics.connections as _);
1839    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1840
1841    let encoder = prometheus::TextEncoder::new();
1842    let metric_families = prometheus::gather();
1843    match encoder.encode_to_string(&metric_families) {
1844        Ok(string) => (StatusCode::OK, string).into_response(),
1845        Err(error) => (
1846            StatusCode::INTERNAL_SERVER_ERROR,
1847            format!("failed to encode metrics {:?}", error),
1848        )
1849            .into_response(),
1850    }
1851}
1852
1853fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1854    match message {
1855        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1856        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1857        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1858        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1859        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1860            code: frame.code.into(),
1861            reason: frame.reason,
1862        })),
1863    }
1864}
1865
1866fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1867    match message {
1868        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1869        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1870        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1871        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1872        AxumMessage::Close(frame) => {
1873            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1874                code: frame.code.into(),
1875                reason: frame.reason,
1876            }))
1877        }
1878    }
1879}
1880
/// Extension for turning a fallible value into an `Option` while reporting
/// the discarded error, instead of propagating it.
pub trait ResultExt {
    /// The success type carried through to the returned `Option`.
    type Ok;

    /// Returns `Some(value)` on success and `None` on failure. The `Result`
    /// impl in this file logs the error with `tracing::error!` before
    /// returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1886
1887impl<T, E> ResultExt for Result<T, E>
1888where
1889    E: std::fmt::Debug,
1890{
1891    type Ok = T;
1892
1893    fn trace_err(self) -> Option<T> {
1894        match self {
1895            Ok(value) => Some(value),
1896            Err(error) => {
1897                tracing::error!("{:?}", error);
1898                None
1899            }
1900        }
1901    }
1902}