rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52pub use store::{Store, Worktree};
  53use tokio::{
  54    sync::{Mutex, MutexGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
// Process-wide Prometheus gauges, created lazily on first access. `register_int_gauge!` registers
// each gauge with prometheus' default registry; the `unwrap`s panic if registration fails
// (for instance if a metric with the same name was already registered).
lazy_static! {
    // Gauge named "connections": number of connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge named "shared_projects": number of open projects with one or more guests.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  69
  70type MessageHandler = Box<
  71    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
  72>;
  73
/// An incoming message (or the request part of a request) as passed to a handler, together with
/// information about who sent it.
struct Message<T> {
    /// Id of the user whose connection sent the message.
    sender_user_id: UserId,
    /// Connection the message arrived on.
    sender_connection_id: ConnectionId,
    /// The message payload taken from the envelope.
    payload: T,
}
  79
  80struct Response<R> {
  81    server: Arc<Server>,
  82    receipt: Receipt<R>,
  83    responded: Arc<AtomicBool>,
  84}
  85
  86impl<R: RequestMessage> Response<R> {
  87    fn send(self, payload: R::Response) -> Result<()> {
  88        self.responded.store(true, SeqCst);
  89        self.server.peer.respond(self.receipt, payload)?;
  90        Ok(())
  91    }
  92}
  93
/// RPC server that dispatches messages arriving from connected clients.
///
/// Holds the [`Peer`] used to talk to connections, the in-memory [`Store`] behind an async
/// (tokio) mutex, the shared application state, and the registered message handlers keyed by the
/// [`TypeId`] of the message type each one handles.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) store: Mutex<Store>,
    app_state: Arc<AppState>,
    handlers: HashMap<TypeId, MessageHandler>,
}
 100
/// Runtime abstraction that `Server::handle_connection` is generic over: it is used to spawn
/// background message handlers and to create the timers handed to the peer.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Runs `future` in the background without the caller awaiting it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration` (presumably resolving once it has elapsed).
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 106
/// Unit-struct [`Executor`]; presumably the implementation backed by the real async runtime
/// (as opposed to a test executor). NOTE(review): confirm against its `Executor` impl.
#[derive(Clone)]
pub struct RealExecutor;
 109
/// Lock guard over the server's [`Store`].
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`: it cannot be kept alive across an
    // `.await` inside a future that is required to be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 114
/// Serializable snapshot of the server: the peer plus the contents of the locked store.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // `StoreGuard` does not implement `Serialize`; serialize the value it dereferences to instead.
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 121
 122pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 123where
 124    S: Serializer,
 125    T: Deref<Target = U>,
 126    U: Serialize,
 127{
 128    Serialize::serialize(value.deref(), serializer)
 129}
 130
 131impl Server {
    /// Creates a server for `app_state` and registers a handler for every message type it serves.
    ///
    /// Handlers registered with `add_request_handler` must reply through a [`Response`]; those
    /// registered with `add_message_handler` receive no response handle. The project requests
    /// (hover, definitions, completions, renames, entry operations, …) all go through the
    /// generic `forward_project_request`.
    ///
    /// # Panics
    ///
    /// Panics if the same message type is registered twice (see `add_handler`).
    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
        };

        server
            .add_request_handler(Server::ping)
            .add_request_handler(Server::create_room)
            .add_request_handler(Server::join_room)
            .add_message_handler(Server::leave_room)
            .add_request_handler(Server::call)
            .add_request_handler(Server::cancel_call)
            .add_message_handler(Server::decline_call)
            .add_request_handler(Server::update_participant_location)
            .add_request_handler(Server::share_project)
            .add_message_handler(Server::unshare_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_request_handler(Server::update_project)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_request_handler(Server::update_diagnostic_summary)
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            .add_message_handler(Server::create_buffer_for_peer)
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_request_handler(Server::update_followers)
            .add_message_handler(Server::update_diff_base)
            .add_request_handler(Server::get_private_user_info);

        Arc::new(server)
    }
 201
 202    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 203    where
 204        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 205        Fut: 'static + Send + Future<Output = Result<()>>,
 206        M: EnvelopedMessage,
 207    {
 208        let prev_handler = self.handlers.insert(
 209            TypeId::of::<M>(),
 210            Box::new(move |server, sender_user_id, envelope| {
 211                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 212                let span = info_span!(
 213                    "handle message",
 214                    payload_type = envelope.payload_type_name()
 215                );
 216                span.in_scope(|| {
 217                    tracing::info!(
 218                        payload_type = envelope.payload_type_name(),
 219                        "message received"
 220                    );
 221                });
 222                let future = (handler)(server, sender_user_id, *envelope);
 223                async move {
 224                    if let Err(error) = future.await {
 225                        tracing::error!(%error, "error handling message");
 226                    }
 227                }
 228                .instrument(span)
 229                .boxed()
 230            }),
 231        );
 232        if prev_handler.is_some() {
 233            panic!("registered a handler for the same message twice");
 234        }
 235        self
 236    }
 237
 238    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 239    where
 240        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 241        Fut: 'static + Send + Future<Output = Result<()>>,
 242        M: EnvelopedMessage,
 243    {
 244        self.add_handler(move |server, sender_user_id, envelope| {
 245            handler(
 246                server,
 247                Message {
 248                    sender_user_id,
 249                    sender_connection_id: envelope.sender_id,
 250                    payload: envelope.payload,
 251                },
 252            )
 253        });
 254        self
 255    }
 256
 257    /// Handle a request while holding a lock to the store. This is useful when we're registering
 258    /// a connection but we want to respond on the connection before anybody else can send on it.
 259    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 260    where
 261        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 262        Fut: Send + Future<Output = Result<()>>,
 263        M: RequestMessage,
 264    {
 265        let handler = Arc::new(handler);
 266        self.add_handler(move |server, sender_user_id, envelope| {
 267            let receipt = envelope.receipt();
 268            let handler = handler.clone();
 269            async move {
 270                let request = Message {
 271                    sender_user_id,
 272                    sender_connection_id: envelope.sender_id,
 273                    payload: envelope.payload,
 274                };
 275                let responded = Arc::new(AtomicBool::default());
 276                let response = Response {
 277                    server: server.clone(),
 278                    responded: responded.clone(),
 279                    receipt,
 280                };
 281                match (handler)(server.clone(), request, response).await {
 282                    Ok(()) => {
 283                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 284                            Ok(())
 285                        } else {
 286                            Err(anyhow!("handler did not send a response"))?
 287                        }
 288                    }
 289                    Err(error) => {
 290                        server.peer.respond_with_error(
 291                            receipt,
 292                            proto::Error {
 293                                message: error.to_string(),
 294                            },
 295                        )?;
 296                        Err(error)
 297                    }
 298                }
 299            }
 300        })
 301    }
 302
 303    pub fn handle_connection<E: Executor>(
 304        self: &Arc<Self>,
 305        connection: Connection,
 306        address: String,
 307        user: User,
 308        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 309        executor: E,
 310    ) -> impl Future<Output = Result<()>> {
 311        let mut this = self.clone();
 312        let user_id = user.id;
 313        let login = user.github_login;
 314        let span = info_span!("handle connection", %user_id, %login, %address);
 315        async move {
 316            let (connection_id, handle_io, mut incoming_rx) = this
 317                .peer
 318                .add_connection(connection, {
 319                    let executor = executor.clone();
 320                    move |duration| {
 321                        let timer = executor.sleep(duration);
 322                        async move {
 323                            timer.await;
 324                        }
 325                    }
 326                });
 327
 328            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 329            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 330            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 331
 332            if let Some(send_connection_id) = send_connection_id.take() {
 333                let _ = send_connection_id.send(connection_id);
 334            }
 335
 336            if !user.connected_once {
 337                this.peer.send(connection_id, proto::ShowContacts {})?;
 338                this.app_state.db.set_user_connected_once(user_id, true).await?;
 339            }
 340
 341            let (contacts, invite_code) = future::try_join(
 342                this.app_state.db.get_contacts(user_id),
 343                this.app_state.db.get_invite_code_for_user(user_id)
 344            ).await?;
 345
 346            {
 347                let mut store = this.store().await;
 348                store.add_connection(connection_id, user_id, user.admin);
 349                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 350
 351                if let Some((code, count)) = invite_code {
 352                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 353                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 354                        count,
 355                    })?;
 356                }
 357            }
 358
 359            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 360                this.peer.send(connection_id, incoming_call)?;
 361            }
 362
 363            this.update_user_contacts(user_id).await?;
 364
 365            let handle_io = handle_io.fuse();
 366            futures::pin_mut!(handle_io);
 367
 368            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 369            // This prevents deadlocks when e.g., client A performs a request to client B and
 370            // client B performs a request to client A. If both clients stop processing further
 371            // messages until their respective request completes, they won't have a chance to
 372            // respond to the other client's request and cause a deadlock.
 373            //
 374            // This arrangement ensures we will attempt to process earlier messages first, but fall
 375            // back to processing messages arrived later in the spirit of making progress.
 376            let mut foreground_message_handlers = FuturesUnordered::new();
 377            loop {
 378                let next_message = incoming_rx.next().fuse();
 379                futures::pin_mut!(next_message);
 380                futures::select_biased! {
 381                    result = handle_io => {
 382                        if let Err(error) = result {
 383                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 384                        }
 385                        break;
 386                    }
 387                    _ = foreground_message_handlers.next() => {}
 388                    message = next_message => {
 389                        if let Some(message) = message {
 390                            let type_name = message.payload_type_name();
 391                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 392                            let span_enter = span.enter();
 393                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 394                                let is_background = message.is_background();
 395                                let handle_message = (handler)(this.clone(), user_id, message);
 396                                drop(span_enter);
 397
 398                                let handle_message = handle_message.instrument(span);
 399                                if is_background {
 400                                    executor.spawn_detached(handle_message);
 401                                } else {
 402                                    foreground_message_handlers.push(handle_message);
 403                                }
 404                            } else {
 405                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 406                            }
 407                        } else {
 408                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 409                            break;
 410                        }
 411                    }
 412                }
 413            }
 414
 415            drop(foreground_message_handlers);
 416            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 417            if let Err(error) = this.sign_out(connection_id, user_id).await {
 418                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 419            }
 420
 421            Ok(())
 422        }.instrument(span)
 423    }
 424
 425    #[instrument(skip(self), err)]
 426    async fn sign_out(
 427        self: &mut Arc<Self>,
 428        connection_id: ConnectionId,
 429        user_id: UserId,
 430    ) -> Result<()> {
 431        self.peer.disconnect(connection_id);
 432        let decline_calls = {
 433            let mut store = self.store().await;
 434            store.remove_connection(connection_id)?;
 435            let mut connections = store.connection_ids_for_user(user_id);
 436            connections.next().is_none()
 437        };
 438
 439        self.leave_room_for_connection(connection_id, user_id)
 440            .await
 441            .trace_err();
 442        if decline_calls {
 443            if let Some(room) = self
 444                .app_state
 445                .db
 446                .decline_call(None, user_id)
 447                .await
 448                .trace_err()
 449            {
 450                self.room_updated(&room);
 451            }
 452        }
 453
 454        self.update_user_contacts(user_id).await?;
 455
 456        Ok(())
 457    }
 458
 459    pub async fn invite_code_redeemed(
 460        self: &Arc<Self>,
 461        inviter_id: UserId,
 462        invitee_id: UserId,
 463    ) -> Result<()> {
 464        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 465            if let Some(code) = &user.invite_code {
 466                let store = self.store().await;
 467                let invitee_contact = store.contact_for_user(invitee_id, true, false);
 468                for connection_id in store.connection_ids_for_user(inviter_id) {
 469                    self.peer.send(
 470                        connection_id,
 471                        proto::UpdateContacts {
 472                            contacts: vec![invitee_contact.clone()],
 473                            ..Default::default()
 474                        },
 475                    )?;
 476                    self.peer.send(
 477                        connection_id,
 478                        proto::UpdateInviteInfo {
 479                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 480                            count: user.invite_count as u32,
 481                        },
 482                    )?;
 483                }
 484            }
 485        }
 486        Ok(())
 487    }
 488
 489    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 490        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 491            if let Some(invite_code) = &user.invite_code {
 492                let store = self.store().await;
 493                for connection_id in store.connection_ids_for_user(user_id) {
 494                    self.peer.send(
 495                        connection_id,
 496                        proto::UpdateInviteInfo {
 497                            url: format!(
 498                                "{}{}",
 499                                self.app_state.config.invite_link_prefix, invite_code
 500                            ),
 501                            count: user.invite_count as u32,
 502                        },
 503                    )?;
 504                }
 505            }
 506        }
 507        Ok(())
 508    }
 509
 510    async fn ping(
 511        self: Arc<Server>,
 512        _: Message<proto::Ping>,
 513        response: Response<proto::Ping>,
 514    ) -> Result<()> {
 515        response.send(proto::Ack {})?;
 516        Ok(())
 517    }
 518
 519    async fn create_room(
 520        self: Arc<Server>,
 521        request: Message<proto::CreateRoom>,
 522        response: Response<proto::CreateRoom>,
 523    ) -> Result<()> {
 524        let room = self
 525            .app_state
 526            .db
 527            .create_room(request.sender_user_id, request.sender_connection_id)
 528            .await?;
 529
 530        let live_kit_connection_info =
 531            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 532                if let Some(_) = live_kit
 533                    .create_room(room.live_kit_room.clone())
 534                    .await
 535                    .trace_err()
 536                {
 537                    if let Some(token) = live_kit
 538                        .room_token(
 539                            &room.live_kit_room,
 540                            &request.sender_connection_id.to_string(),
 541                        )
 542                        .trace_err()
 543                    {
 544                        Some(proto::LiveKitConnectionInfo {
 545                            server_url: live_kit.url().into(),
 546                            token,
 547                        })
 548                    } else {
 549                        None
 550                    }
 551                } else {
 552                    None
 553                }
 554            } else {
 555                None
 556            };
 557
 558        response.send(proto::CreateRoomResponse {
 559            room: Some(room),
 560            live_kit_connection_info,
 561        })?;
 562        self.update_user_contacts(request.sender_user_id).await?;
 563        Ok(())
 564    }
 565
 566    async fn join_room(
 567        self: Arc<Server>,
 568        request: Message<proto::JoinRoom>,
 569        response: Response<proto::JoinRoom>,
 570    ) -> Result<()> {
 571        let room = self
 572            .app_state
 573            .db
 574            .join_room(
 575                RoomId::from_proto(request.payload.id),
 576                request.sender_user_id,
 577                request.sender_connection_id,
 578            )
 579            .await?;
 580        for connection_id in self
 581            .store()
 582            .await
 583            .connection_ids_for_user(request.sender_user_id)
 584        {
 585            self.peer
 586                .send(connection_id, proto::CallCanceled {})
 587                .trace_err();
 588        }
 589
 590        let live_kit_connection_info =
 591            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 592                if let Some(token) = live_kit
 593                    .room_token(
 594                        &room.live_kit_room,
 595                        &request.sender_connection_id.to_string(),
 596                    )
 597                    .trace_err()
 598                {
 599                    Some(proto::LiveKitConnectionInfo {
 600                        server_url: live_kit.url().into(),
 601                        token,
 602                    })
 603                } else {
 604                    None
 605                }
 606            } else {
 607                None
 608            };
 609
 610        self.room_updated(&room);
 611        response.send(proto::JoinRoomResponse {
 612            room: Some(room),
 613            live_kit_connection_info,
 614        })?;
 615
 616        self.update_user_contacts(request.sender_user_id).await?;
 617        Ok(())
 618    }
 619
 620    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 621        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
 622            .await
 623    }
 624
 625    async fn leave_room_for_connection(
 626        self: &Arc<Server>,
 627        leaving_connection_id: ConnectionId,
 628        leaving_user_id: UserId,
 629    ) -> Result<()> {
 630        let mut contacts_to_update = HashSet::default();
 631
 632        let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else {
 633            return Err(anyhow!("no room to leave"))?;
 634        };
 635        contacts_to_update.insert(leaving_user_id);
 636
 637        for project in left_room.left_projects.into_values() {
 638            for connection_id in project.connection_ids {
 639                if project.host_user_id == leaving_user_id {
 640                    self.peer
 641                        .send(
 642                            connection_id,
 643                            proto::UnshareProject {
 644                                project_id: project.id.to_proto(),
 645                            },
 646                        )
 647                        .trace_err();
 648                } else {
 649                    self.peer
 650                        .send(
 651                            connection_id,
 652                            proto::RemoveProjectCollaborator {
 653                                project_id: project.id.to_proto(),
 654                                peer_id: leaving_connection_id.0,
 655                            },
 656                        )
 657                        .trace_err();
 658                }
 659            }
 660
 661            self.peer
 662                .send(
 663                    leaving_connection_id,
 664                    proto::UnshareProject {
 665                        project_id: project.id.to_proto(),
 666                    },
 667                )
 668                .trace_err();
 669        }
 670
 671        self.room_updated(&left_room.room);
 672        {
 673            let store = self.store().await;
 674            for canceled_user_id in left_room.canceled_calls_to_user_ids {
 675                for connection_id in store.connection_ids_for_user(canceled_user_id) {
 676                    self.peer
 677                        .send(connection_id, proto::CallCanceled {})
 678                        .trace_err();
 679                }
 680                contacts_to_update.insert(canceled_user_id);
 681            }
 682        }
 683
 684        for contact_user_id in contacts_to_update {
 685            self.update_user_contacts(contact_user_id).await?;
 686        }
 687
 688        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 689            live_kit
 690                .remove_participant(
 691                    left_room.room.live_kit_room.clone(),
 692                    leaving_connection_id.to_string(),
 693                )
 694                .await
 695                .trace_err();
 696
 697            if left_room.room.participants.is_empty() {
 698                live_kit
 699                    .delete_room(left_room.room.live_kit_room)
 700                    .await
 701                    .trace_err();
 702            }
 703        }
 704
 705        Ok(())
 706    }
 707
 708    async fn call(
 709        self: Arc<Server>,
 710        request: Message<proto::Call>,
 711        response: Response<proto::Call>,
 712    ) -> Result<()> {
 713        let room_id = RoomId::from_proto(request.payload.room_id);
 714        let calling_user_id = request.sender_user_id;
 715        let calling_connection_id = request.sender_connection_id;
 716        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 717        let initial_project_id = request
 718            .payload
 719            .initial_project_id
 720            .map(ProjectId::from_proto);
 721        if !self
 722            .app_state
 723            .db
 724            .has_contact(calling_user_id, called_user_id)
 725            .await?
 726        {
 727            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 728        }
 729
 730        let (room, incoming_call) = self
 731            .app_state
 732            .db
 733            .call(
 734                room_id,
 735                calling_user_id,
 736                calling_connection_id,
 737                called_user_id,
 738                initial_project_id,
 739            )
 740            .await?;
 741        self.room_updated(&room);
 742        self.update_user_contacts(called_user_id).await?;
 743
 744        let mut calls = self
 745            .store()
 746            .await
 747            .connection_ids_for_user(called_user_id)
 748            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 749            .collect::<FuturesUnordered<_>>();
 750
 751        while let Some(call_response) = calls.next().await {
 752            match call_response.as_ref() {
 753                Ok(_) => {
 754                    response.send(proto::Ack {})?;
 755                    return Ok(());
 756                }
 757                Err(_) => {
 758                    call_response.trace_err();
 759                }
 760            }
 761        }
 762
 763        let room = self
 764            .app_state
 765            .db
 766            .call_failed(room_id, called_user_id)
 767            .await?;
 768        self.room_updated(&room);
 769        self.update_user_contacts(called_user_id).await?;
 770
 771        Err(anyhow!("failed to ring user"))?
 772    }
 773
 774    async fn cancel_call(
 775        self: Arc<Server>,
 776        request: Message<proto::CancelCall>,
 777        response: Response<proto::CancelCall>,
 778    ) -> Result<()> {
 779        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 780        let room_id = RoomId::from_proto(request.payload.room_id);
 781        let room = self
 782            .app_state
 783            .db
 784            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
 785            .await?;
 786        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
 787            self.peer
 788                .send(connection_id, proto::CallCanceled {})
 789                .trace_err();
 790        }
 791        self.room_updated(&room);
 792        response.send(proto::Ack {})?;
 793
 794        self.update_user_contacts(called_user_id).await?;
 795        Ok(())
 796    }
 797
 798    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 799        let room_id = RoomId::from_proto(message.payload.room_id);
 800        let room = self
 801            .app_state
 802            .db
 803            .decline_call(Some(room_id), message.sender_user_id)
 804            .await?;
 805        for connection_id in self
 806            .store()
 807            .await
 808            .connection_ids_for_user(message.sender_user_id)
 809        {
 810            self.peer
 811                .send(connection_id, proto::CallCanceled {})
 812                .trace_err();
 813        }
 814        self.room_updated(&room);
 815        self.update_user_contacts(message.sender_user_id).await?;
 816        Ok(())
 817    }
 818
 819    async fn update_participant_location(
 820        self: Arc<Server>,
 821        request: Message<proto::UpdateParticipantLocation>,
 822        response: Response<proto::UpdateParticipantLocation>,
 823    ) -> Result<()> {
 824        let room_id = RoomId::from_proto(request.payload.room_id);
 825        let location = request
 826            .payload
 827            .location
 828            .ok_or_else(|| anyhow!("invalid location"))?;
 829        let room = self
 830            .app_state
 831            .db
 832            .update_room_participant_location(room_id, request.sender_connection_id, location)
 833            .await?;
 834        self.room_updated(&room);
 835        response.send(proto::Ack {})?;
 836        Ok(())
 837    }
 838
 839    fn room_updated(&self, room: &proto::Room) {
 840        for participant in &room.participants {
 841            self.peer
 842                .send(
 843                    ConnectionId(participant.peer_id),
 844                    proto::RoomUpdated {
 845                        room: Some(room.clone()),
 846                    },
 847                )
 848                .trace_err();
 849        }
 850    }
 851
 852    async fn share_project(
 853        self: Arc<Server>,
 854        request: Message<proto::ShareProject>,
 855        response: Response<proto::ShareProject>,
 856    ) -> Result<()> {
 857        let (project_id, room) = self
 858            .app_state
 859            .db
 860            .share_project(
 861                RoomId::from_proto(request.payload.room_id),
 862                request.sender_connection_id,
 863                &request.payload.worktrees,
 864            )
 865            .await?;
 866        response.send(proto::ShareProjectResponse {
 867            project_id: project_id.to_proto(),
 868        })?;
 869        self.room_updated(&room);
 870
 871        Ok(())
 872    }
 873
 874    async fn unshare_project(
 875        self: Arc<Server>,
 876        message: Message<proto::UnshareProject>,
 877    ) -> Result<()> {
 878        let project_id = ProjectId::from_proto(message.payload.project_id);
 879
 880        let (room, guest_connection_ids) = self
 881            .app_state
 882            .db
 883            .unshare_project(project_id, message.sender_connection_id)
 884            .await?;
 885
 886        broadcast(
 887            message.sender_connection_id,
 888            guest_connection_ids,
 889            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 890        );
 891        self.room_updated(&room);
 892
 893        Ok(())
 894    }
 895
 896    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 897        let contacts = self.app_state.db.get_contacts(user_id).await?;
 898        let busy = self.app_state.db.is_user_busy(user_id).await?;
 899        let store = self.store().await;
 900        let updated_contact = store.contact_for_user(user_id, false, busy);
 901        for contact in contacts {
 902            if let db::Contact::Accepted {
 903                user_id: contact_user_id,
 904                ..
 905            } = contact
 906            {
 907                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 908                    self.peer
 909                        .send(
 910                            contact_conn_id,
 911                            proto::UpdateContacts {
 912                                contacts: vec![updated_contact.clone()],
 913                                remove_contacts: Default::default(),
 914                                incoming_requests: Default::default(),
 915                                remove_incoming_requests: Default::default(),
 916                                outgoing_requests: Default::default(),
 917                                remove_outgoing_requests: Default::default(),
 918                            },
 919                        )
 920                        .trace_err();
 921                }
 922            }
 923        }
 924        Ok(())
 925    }
 926
 927    async fn join_project(
 928        self: Arc<Server>,
 929        request: Message<proto::JoinProject>,
 930        response: Response<proto::JoinProject>,
 931    ) -> Result<()> {
 932        let project_id = ProjectId::from_proto(request.payload.project_id);
 933        let guest_user_id = request.sender_user_id;
 934
 935        tracing::info!(%project_id, "join project");
 936
 937        let (project, replica_id) = self
 938            .app_state
 939            .db
 940            .join_project(project_id, request.sender_connection_id)
 941            .await?;
 942
 943        let collaborators = project
 944            .collaborators
 945            .iter()
 946            .filter(|collaborator| {
 947                collaborator.connection_id != request.sender_connection_id.0 as i32
 948            })
 949            .map(|collaborator| proto::Collaborator {
 950                peer_id: collaborator.connection_id as u32,
 951                replica_id: collaborator.replica_id.0 as u32,
 952                user_id: collaborator.user_id.to_proto(),
 953            })
 954            .collect::<Vec<_>>();
 955        let worktrees = project
 956            .worktrees
 957            .iter()
 958            .map(|(id, worktree)| proto::WorktreeMetadata {
 959                id: id.to_proto(),
 960                root_name: worktree.root_name.clone(),
 961                visible: worktree.visible,
 962                abs_path: worktree.abs_path.clone(),
 963            })
 964            .collect::<Vec<_>>();
 965
 966        for collaborator in &collaborators {
 967            self.peer
 968                .send(
 969                    ConnectionId(collaborator.peer_id),
 970                    proto::AddProjectCollaborator {
 971                        project_id: project_id.to_proto(),
 972                        collaborator: Some(proto::Collaborator {
 973                            peer_id: request.sender_connection_id.0,
 974                            replica_id: replica_id.0 as u32,
 975                            user_id: guest_user_id.to_proto(),
 976                        }),
 977                    },
 978                )
 979                .trace_err();
 980        }
 981
 982        // First, we send the metadata associated with each worktree.
 983        response.send(proto::JoinProjectResponse {
 984            worktrees: worktrees.clone(),
 985            replica_id: replica_id.0 as u32,
 986            collaborators: collaborators.clone(),
 987            language_servers: project.language_servers.clone(),
 988        })?;
 989
 990        for (worktree_id, worktree) in project.worktrees {
 991            #[cfg(any(test, feature = "test-support"))]
 992            const MAX_CHUNK_SIZE: usize = 2;
 993            #[cfg(not(any(test, feature = "test-support")))]
 994            const MAX_CHUNK_SIZE: usize = 256;
 995
 996            // Stream this worktree's entries.
 997            let message = proto::UpdateWorktree {
 998                project_id: project_id.to_proto(),
 999                worktree_id: worktree_id.to_proto(),
1000                abs_path: worktree.abs_path.clone(),
1001                root_name: worktree.root_name,
1002                updated_entries: worktree.entries,
1003                removed_entries: Default::default(),
1004                scan_id: worktree.scan_id,
1005                is_last_update: worktree.is_complete,
1006            };
1007            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1008                self.peer
1009                    .send(request.sender_connection_id, update.clone())?;
1010            }
1011
1012            // Stream this worktree's diagnostics.
1013            for summary in worktree.diagnostic_summaries {
1014                self.peer.send(
1015                    request.sender_connection_id,
1016                    proto::UpdateDiagnosticSummary {
1017                        project_id: project_id.to_proto(),
1018                        worktree_id: worktree.id.to_proto(),
1019                        summary: Some(summary),
1020                    },
1021                )?;
1022            }
1023        }
1024
1025        for language_server in &project.language_servers {
1026            self.peer.send(
1027                request.sender_connection_id,
1028                proto::UpdateLanguageServer {
1029                    project_id: project_id.to_proto(),
1030                    language_server_id: language_server.id,
1031                    variant: Some(
1032                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1033                            proto::LspDiskBasedDiagnosticsUpdated {},
1034                        ),
1035                    ),
1036                },
1037            )?;
1038        }
1039
1040        Ok(())
1041    }
1042
1043    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1044        let sender_id = request.sender_connection_id;
1045        let project_id = ProjectId::from_proto(request.payload.project_id);
1046        let project;
1047        {
1048            project = self
1049                .app_state
1050                .db
1051                .leave_project(project_id, sender_id)
1052                .await?;
1053            tracing::info!(
1054                %project_id,
1055                host_user_id = %project.host_user_id,
1056                host_connection_id = %project.host_connection_id,
1057                "leave project"
1058            );
1059
1060            broadcast(sender_id, project.connection_ids, |conn_id| {
1061                self.peer.send(
1062                    conn_id,
1063                    proto::RemoveProjectCollaborator {
1064                        project_id: project_id.to_proto(),
1065                        peer_id: sender_id.0,
1066                    },
1067                )
1068            });
1069        }
1070
1071        Ok(())
1072    }
1073
1074    async fn update_project(
1075        self: Arc<Server>,
1076        request: Message<proto::UpdateProject>,
1077        response: Response<proto::UpdateProject>,
1078    ) -> Result<()> {
1079        let project_id = ProjectId::from_proto(request.payload.project_id);
1080        let (room, guest_connection_ids) = self
1081            .app_state
1082            .db
1083            .update_project(
1084                project_id,
1085                request.sender_connection_id,
1086                &request.payload.worktrees,
1087            )
1088            .await?;
1089        broadcast(
1090            request.sender_connection_id,
1091            guest_connection_ids,
1092            |connection_id| {
1093                self.peer.forward_send(
1094                    request.sender_connection_id,
1095                    connection_id,
1096                    request.payload.clone(),
1097                )
1098            },
1099        );
1100        self.room_updated(&room);
1101        response.send(proto::Ack {})?;
1102
1103        Ok(())
1104    }
1105
1106    async fn update_worktree(
1107        self: Arc<Server>,
1108        request: Message<proto::UpdateWorktree>,
1109        response: Response<proto::UpdateWorktree>,
1110    ) -> Result<()> {
1111        let guest_connection_ids = self
1112            .app_state
1113            .db
1114            .update_worktree(&request.payload, request.sender_connection_id)
1115            .await?;
1116
1117        broadcast(
1118            request.sender_connection_id,
1119            guest_connection_ids,
1120            |connection_id| {
1121                self.peer.forward_send(
1122                    request.sender_connection_id,
1123                    connection_id,
1124                    request.payload.clone(),
1125                )
1126            },
1127        );
1128        response.send(proto::Ack {})?;
1129        Ok(())
1130    }
1131
1132    async fn update_diagnostic_summary(
1133        self: Arc<Server>,
1134        request: Message<proto::UpdateDiagnosticSummary>,
1135        response: Response<proto::UpdateDiagnosticSummary>,
1136    ) -> Result<()> {
1137        let guest_connection_ids = self
1138            .app_state
1139            .db
1140            .update_diagnostic_summary(&request.payload, request.sender_connection_id)
1141            .await?;
1142
1143        broadcast(
1144            request.sender_connection_id,
1145            guest_connection_ids,
1146            |connection_id| {
1147                self.peer.forward_send(
1148                    request.sender_connection_id,
1149                    connection_id,
1150                    request.payload.clone(),
1151                )
1152            },
1153        );
1154
1155        response.send(proto::Ack {})?;
1156        Ok(())
1157    }
1158
1159    async fn start_language_server(
1160        self: Arc<Server>,
1161        request: Message<proto::StartLanguageServer>,
1162    ) -> Result<()> {
1163        let guest_connection_ids = self
1164            .app_state
1165            .db
1166            .start_language_server(&request.payload, request.sender_connection_id)
1167            .await?;
1168
1169        broadcast(
1170            request.sender_connection_id,
1171            guest_connection_ids,
1172            |connection_id| {
1173                self.peer.forward_send(
1174                    request.sender_connection_id,
1175                    connection_id,
1176                    request.payload.clone(),
1177                )
1178            },
1179        );
1180        Ok(())
1181    }
1182
1183    async fn update_language_server(
1184        self: Arc<Server>,
1185        request: Message<proto::UpdateLanguageServer>,
1186    ) -> Result<()> {
1187        let project_id = ProjectId::from_proto(request.payload.project_id);
1188        let project_connection_ids = self
1189            .app_state
1190            .db
1191            .project_connection_ids(project_id, request.sender_connection_id)
1192            .await?;
1193        broadcast(
1194            request.sender_connection_id,
1195            project_connection_ids,
1196            |connection_id| {
1197                self.peer.forward_send(
1198                    request.sender_connection_id,
1199                    connection_id,
1200                    request.payload.clone(),
1201                )
1202            },
1203        );
1204        Ok(())
1205    }
1206
1207    async fn forward_project_request<T>(
1208        self: Arc<Server>,
1209        request: Message<T>,
1210        response: Response<T>,
1211    ) -> Result<()>
1212    where
1213        T: EntityMessage + RequestMessage,
1214    {
1215        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1216        let collaborators = self
1217            .app_state
1218            .db
1219            .project_collaborators(project_id, request.sender_connection_id)
1220            .await?;
1221        let host = collaborators
1222            .iter()
1223            .find(|collaborator| collaborator.is_host)
1224            .ok_or_else(|| anyhow!("host not found"))?;
1225
1226        let payload = self
1227            .peer
1228            .forward_request(
1229                request.sender_connection_id,
1230                ConnectionId(host.connection_id as u32),
1231                request.payload,
1232            )
1233            .await?;
1234
1235        response.send(payload)?;
1236        Ok(())
1237    }
1238
1239    async fn save_buffer(
1240        self: Arc<Server>,
1241        request: Message<proto::SaveBuffer>,
1242        response: Response<proto::SaveBuffer>,
1243    ) -> Result<()> {
1244        let project_id = ProjectId::from_proto(request.payload.project_id);
1245        let collaborators = self
1246            .app_state
1247            .db
1248            .project_collaborators(project_id, request.sender_connection_id)
1249            .await?;
1250        let host = collaborators
1251            .into_iter()
1252            .find(|collaborator| collaborator.is_host)
1253            .ok_or_else(|| anyhow!("host not found"))?;
1254        let host_connection_id = ConnectionId(host.connection_id as u32);
1255        let response_payload = self
1256            .peer
1257            .forward_request(
1258                request.sender_connection_id,
1259                host_connection_id,
1260                request.payload.clone(),
1261            )
1262            .await?;
1263
1264        let mut collaborators = self
1265            .app_state
1266            .db
1267            .project_collaborators(project_id, request.sender_connection_id)
1268            .await?;
1269        collaborators.retain(|collaborator| {
1270            collaborator.connection_id != request.sender_connection_id.0 as i32
1271        });
1272        let project_connection_ids = collaborators
1273            .into_iter()
1274            .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1275        broadcast(host_connection_id, project_connection_ids, |conn_id| {
1276            self.peer
1277                .forward_send(host_connection_id, conn_id, response_payload.clone())
1278        });
1279        response.send(response_payload)?;
1280        Ok(())
1281    }
1282
1283    async fn create_buffer_for_peer(
1284        self: Arc<Server>,
1285        request: Message<proto::CreateBufferForPeer>,
1286    ) -> Result<()> {
1287        self.peer.forward_send(
1288            request.sender_connection_id,
1289            ConnectionId(request.payload.peer_id),
1290            request.payload,
1291        )?;
1292        Ok(())
1293    }
1294
1295    async fn update_buffer(
1296        self: Arc<Server>,
1297        request: Message<proto::UpdateBuffer>,
1298        response: Response<proto::UpdateBuffer>,
1299    ) -> Result<()> {
1300        let project_id = ProjectId::from_proto(request.payload.project_id);
1301        let project_connection_ids = self
1302            .app_state
1303            .db
1304            .project_connection_ids(project_id, request.sender_connection_id)
1305            .await?;
1306
1307        broadcast(
1308            request.sender_connection_id,
1309            project_connection_ids,
1310            |connection_id| {
1311                self.peer.forward_send(
1312                    request.sender_connection_id,
1313                    connection_id,
1314                    request.payload.clone(),
1315                )
1316            },
1317        );
1318        response.send(proto::Ack {})?;
1319        Ok(())
1320    }
1321
1322    async fn update_buffer_file(
1323        self: Arc<Server>,
1324        request: Message<proto::UpdateBufferFile>,
1325    ) -> Result<()> {
1326        let project_id = ProjectId::from_proto(request.payload.project_id);
1327        let project_connection_ids = self
1328            .app_state
1329            .db
1330            .project_connection_ids(project_id, request.sender_connection_id)
1331            .await?;
1332
1333        broadcast(
1334            request.sender_connection_id,
1335            project_connection_ids,
1336            |connection_id| {
1337                self.peer.forward_send(
1338                    request.sender_connection_id,
1339                    connection_id,
1340                    request.payload.clone(),
1341                )
1342            },
1343        );
1344        Ok(())
1345    }
1346
1347    async fn buffer_reloaded(
1348        self: Arc<Server>,
1349        request: Message<proto::BufferReloaded>,
1350    ) -> Result<()> {
1351        let project_id = ProjectId::from_proto(request.payload.project_id);
1352        let project_connection_ids = self
1353            .app_state
1354            .db
1355            .project_connection_ids(project_id, request.sender_connection_id)
1356            .await?;
1357        broadcast(
1358            request.sender_connection_id,
1359            project_connection_ids,
1360            |connection_id| {
1361                self.peer.forward_send(
1362                    request.sender_connection_id,
1363                    connection_id,
1364                    request.payload.clone(),
1365                )
1366            },
1367        );
1368        Ok(())
1369    }
1370
1371    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1372        let project_id = ProjectId::from_proto(request.payload.project_id);
1373        let project_connection_ids = self
1374            .app_state
1375            .db
1376            .project_connection_ids(project_id, request.sender_connection_id)
1377            .await?;
1378        broadcast(
1379            request.sender_connection_id,
1380            project_connection_ids,
1381            |connection_id| {
1382                self.peer.forward_send(
1383                    request.sender_connection_id,
1384                    connection_id,
1385                    request.payload.clone(),
1386                )
1387            },
1388        );
1389        Ok(())
1390    }
1391
1392    async fn follow(
1393        self: Arc<Self>,
1394        request: Message<proto::Follow>,
1395        response: Response<proto::Follow>,
1396    ) -> Result<()> {
1397        let project_id = ProjectId::from_proto(request.payload.project_id);
1398        let leader_id = ConnectionId(request.payload.leader_id);
1399        let follower_id = request.sender_connection_id;
1400        let project_connection_ids = self
1401            .app_state
1402            .db
1403            .project_connection_ids(project_id, request.sender_connection_id)
1404            .await?;
1405
1406        if !project_connection_ids.contains(&leader_id) {
1407            Err(anyhow!("no such peer"))?;
1408        }
1409
1410        let mut response_payload = self
1411            .peer
1412            .forward_request(request.sender_connection_id, leader_id, request.payload)
1413            .await?;
1414        response_payload
1415            .views
1416            .retain(|view| view.leader_id != Some(follower_id.0));
1417        response.send(response_payload)?;
1418        Ok(())
1419    }
1420
1421    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1422        let project_id = ProjectId::from_proto(request.payload.project_id);
1423        let leader_id = ConnectionId(request.payload.leader_id);
1424        let project_connection_ids = self
1425            .app_state
1426            .db
1427            .project_connection_ids(project_id, request.sender_connection_id)
1428            .await?;
1429        if !project_connection_ids.contains(&leader_id) {
1430            Err(anyhow!("no such peer"))?;
1431        }
1432        self.peer
1433            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1434        Ok(())
1435    }
1436
1437    async fn update_followers(
1438        self: Arc<Self>,
1439        request: Message<proto::UpdateFollowers>,
1440        response: Response<proto::UpdateFollowers>,
1441    ) -> Result<()> {
1442        let project_id = ProjectId::from_proto(request.payload.project_id);
1443        let project_connection_ids = self
1444            .app_state
1445            .db
1446            .project_connection_ids(project_id, request.sender_connection_id)
1447            .await?;
1448
1449        let leader_id = request
1450            .payload
1451            .variant
1452            .as_ref()
1453            .and_then(|variant| match variant {
1454                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1455                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1456                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1457            });
1458        for follower_id in &request.payload.follower_ids {
1459            let follower_id = ConnectionId(*follower_id);
1460            if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1461                self.peer.forward_send(
1462                    request.sender_connection_id,
1463                    follower_id,
1464                    request.payload.clone(),
1465                )?;
1466            }
1467        }
1468        response.send(proto::Ack {})?;
1469        Ok(())
1470    }
1471
1472    async fn get_users(
1473        self: Arc<Server>,
1474        request: Message<proto::GetUsers>,
1475        response: Response<proto::GetUsers>,
1476    ) -> Result<()> {
1477        let user_ids = request
1478            .payload
1479            .user_ids
1480            .into_iter()
1481            .map(UserId::from_proto)
1482            .collect();
1483        let users = self
1484            .app_state
1485            .db
1486            .get_users_by_ids(user_ids)
1487            .await?
1488            .into_iter()
1489            .map(|user| proto::User {
1490                id: user.id.to_proto(),
1491                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1492                github_login: user.github_login,
1493            })
1494            .collect();
1495        response.send(proto::UsersResponse { users })?;
1496        Ok(())
1497    }
1498
1499    async fn fuzzy_search_users(
1500        self: Arc<Server>,
1501        request: Message<proto::FuzzySearchUsers>,
1502        response: Response<proto::FuzzySearchUsers>,
1503    ) -> Result<()> {
1504        let query = request.payload.query;
1505        let db = &self.app_state.db;
1506        let users = match query.len() {
1507            0 => vec![],
1508            1 | 2 => db
1509                .get_user_by_github_account(&query, None)
1510                .await?
1511                .into_iter()
1512                .collect(),
1513            _ => db.fuzzy_search_users(&query, 10).await?,
1514        };
1515        let users = users
1516            .into_iter()
1517            .filter(|user| user.id != request.sender_user_id)
1518            .map(|user| proto::User {
1519                id: user.id.to_proto(),
1520                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1521                github_login: user.github_login,
1522            })
1523            .collect();
1524        response.send(proto::UsersResponse { users })?;
1525        Ok(())
1526    }
1527
1528    async fn request_contact(
1529        self: Arc<Server>,
1530        request: Message<proto::RequestContact>,
1531        response: Response<proto::RequestContact>,
1532    ) -> Result<()> {
1533        let requester_id = request.sender_user_id;
1534        let responder_id = UserId::from_proto(request.payload.responder_id);
1535        if requester_id == responder_id {
1536            return Err(anyhow!("cannot add yourself as a contact"))?;
1537        }
1538
1539        self.app_state
1540            .db
1541            .send_contact_request(requester_id, responder_id)
1542            .await?;
1543
1544        // Update outgoing contact requests of requester
1545        let mut update = proto::UpdateContacts::default();
1546        update.outgoing_requests.push(responder_id.to_proto());
1547        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1548            self.peer.send(connection_id, update.clone())?;
1549        }
1550
1551        // Update incoming contact requests of responder
1552        let mut update = proto::UpdateContacts::default();
1553        update
1554            .incoming_requests
1555            .push(proto::IncomingContactRequest {
1556                requester_id: requester_id.to_proto(),
1557                should_notify: true,
1558            });
1559        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1560            self.peer.send(connection_id, update.clone())?;
1561        }
1562
1563        response.send(proto::Ack {})?;
1564        Ok(())
1565    }
1566
1567    async fn respond_to_contact_request(
1568        self: Arc<Server>,
1569        request: Message<proto::RespondToContactRequest>,
1570        response: Response<proto::RespondToContactRequest>,
1571    ) -> Result<()> {
1572        let responder_id = request.sender_user_id;
1573        let requester_id = UserId::from_proto(request.payload.requester_id);
1574        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1575            self.app_state
1576                .db
1577                .dismiss_contact_notification(responder_id, requester_id)
1578                .await?;
1579        } else {
1580            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1581            self.app_state
1582                .db
1583                .respond_to_contact_request(responder_id, requester_id, accept)
1584                .await?;
1585            let busy = self.app_state.db.is_user_busy(requester_id).await?;
1586
1587            let store = self.store().await;
1588            // Update responder with new contact
1589            let mut update = proto::UpdateContacts::default();
1590            if accept {
1591                update
1592                    .contacts
1593                    .push(store.contact_for_user(requester_id, false, busy));
1594            }
1595            update
1596                .remove_incoming_requests
1597                .push(requester_id.to_proto());
1598            for connection_id in store.connection_ids_for_user(responder_id) {
1599                self.peer.send(connection_id, update.clone())?;
1600            }
1601
1602            // Update requester with new contact
1603            let mut update = proto::UpdateContacts::default();
1604            if accept {
1605                update
1606                    .contacts
1607                    .push(store.contact_for_user(responder_id, true, busy));
1608            }
1609            update
1610                .remove_outgoing_requests
1611                .push(responder_id.to_proto());
1612            for connection_id in store.connection_ids_for_user(requester_id) {
1613                self.peer.send(connection_id, update.clone())?;
1614            }
1615        }
1616
1617        response.send(proto::Ack {})?;
1618        Ok(())
1619    }
1620
1621    async fn remove_contact(
1622        self: Arc<Server>,
1623        request: Message<proto::RemoveContact>,
1624        response: Response<proto::RemoveContact>,
1625    ) -> Result<()> {
1626        let requester_id = request.sender_user_id;
1627        let responder_id = UserId::from_proto(request.payload.user_id);
1628        self.app_state
1629            .db
1630            .remove_contact(requester_id, responder_id)
1631            .await?;
1632
1633        // Update outgoing contact requests of requester
1634        let mut update = proto::UpdateContacts::default();
1635        update
1636            .remove_outgoing_requests
1637            .push(responder_id.to_proto());
1638        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1639            self.peer.send(connection_id, update.clone())?;
1640        }
1641
1642        // Update incoming contact requests of responder
1643        let mut update = proto::UpdateContacts::default();
1644        update
1645            .remove_incoming_requests
1646            .push(requester_id.to_proto());
1647        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1648            self.peer.send(connection_id, update.clone())?;
1649        }
1650
1651        response.send(proto::Ack {})?;
1652        Ok(())
1653    }
1654
1655    async fn update_diff_base(
1656        self: Arc<Server>,
1657        request: Message<proto::UpdateDiffBase>,
1658    ) -> Result<()> {
1659        let project_id = ProjectId::from_proto(request.payload.project_id);
1660        let project_connection_ids = self
1661            .app_state
1662            .db
1663            .project_connection_ids(project_id, request.sender_connection_id)
1664            .await?;
1665        broadcast(
1666            request.sender_connection_id,
1667            project_connection_ids,
1668            |connection_id| {
1669                self.peer.forward_send(
1670                    request.sender_connection_id,
1671                    connection_id,
1672                    request.payload.clone(),
1673                )
1674            },
1675        );
1676        Ok(())
1677    }
1678
1679    async fn get_private_user_info(
1680        self: Arc<Self>,
1681        request: Message<proto::GetPrivateUserInfo>,
1682        response: Response<proto::GetPrivateUserInfo>,
1683    ) -> Result<()> {
1684        let metrics_id = self
1685            .app_state
1686            .db
1687            .get_user_metrics_id(request.sender_user_id)
1688            .await?;
1689        let user = self
1690            .app_state
1691            .db
1692            .get_user_by_id(request.sender_user_id)
1693            .await?
1694            .ok_or_else(|| anyhow!("user not found"))?;
1695        response.send(proto::GetPrivateUserInfoResponse {
1696            metrics_id,
1697            staff: user.admin,
1698        })?;
1699        Ok(())
1700    }
1701
1702    pub(crate) async fn store(&self) -> StoreGuard<'_> {
1703        #[cfg(test)]
1704        tokio::task::yield_now().await;
1705        let guard = self.store.lock().await;
1706        #[cfg(test)]
1707        tokio::task::yield_now().await;
1708        StoreGuard {
1709            guard,
1710            _not_send: PhantomData,
1711        }
1712    }
1713
1714    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1715        ServerSnapshot {
1716            store: self.store().await,
1717            peer: &self.peer,
1718        }
1719    }
1720}
1721
1722impl<'a> Deref for StoreGuard<'a> {
1723    type Target = Store;
1724
1725    fn deref(&self) -> &Self::Target {
1726        &*self.guard
1727    }
1728}
1729
1730impl<'a> DerefMut for StoreGuard<'a> {
1731    fn deref_mut(&mut self) -> &mut Self::Target {
1732        &mut *self.guard
1733    }
1734}
1735
1736impl<'a> Drop for StoreGuard<'a> {
1737    fn drop(&mut self) {
1738        #[cfg(test)]
1739        self.check_invariants();
1740    }
1741}
1742
1743impl Executor for RealExecutor {
1744    type Sleep = Sleep;
1745
1746    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1747        tokio::task::spawn(future);
1748    }
1749
1750    fn sleep(&self, duration: Duration) -> Self::Sleep {
1751        tokio::time::sleep(duration)
1752    }
1753}
1754
1755fn broadcast<F>(
1756    sender_id: ConnectionId,
1757    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1758    mut f: F,
1759) where
1760    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1761{
1762    for receiver_id in receiver_ids {
1763        if receiver_id != sender_id {
1764            f(receiver_id).trace_err();
1765        }
1766    }
1767}
1768
lazy_static! {
    /// Name of the request header decoded into `ProtocolVersion` by the websocket
    /// handler and compared against `rpc::PROTOCOL_VERSION`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1772
/// Typed value of the `x-zed-protocol-version` request header.
///
/// Derives the standard value-type traits so the header can be logged, copied, and
/// compared like the plain `u32` it wraps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProtocolVersion(u32);
1774
1775impl Header for ProtocolVersion {
1776    fn name() -> &'static HeaderName {
1777        &ZED_PROTOCOL_VERSION
1778    }
1779
1780    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1781    where
1782        Self: Sized,
1783        I: Iterator<Item = &'i axum::http::HeaderValue>,
1784    {
1785        let version = values
1786            .next()
1787            .ok_or_else(axum::headers::Error::invalid)?
1788            .to_str()
1789            .map_err(|_| axum::headers::Error::invalid())?
1790            .parse()
1791            .map_err(|_| axum::headers::Error::invalid())?;
1792        Ok(Self(version))
1793    }
1794
1795    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1796        values.extend([self.0.to_string().parse().unwrap()]);
1797    }
1798}
1799
1800pub fn routes(server: Arc<Server>) -> Router<Body> {
1801    Router::new()
1802        .route("/rpc", get(handle_websocket_request))
1803        .layer(
1804            ServiceBuilder::new()
1805                .layer(Extension(server.app_state.clone()))
1806                .layer(middleware::from_fn(auth::validate_header)),
1807        )
1808        .route("/metrics", get(handle_metrics))
1809        .layer(Extension(server))
1810}
1811
1812pub async fn handle_websocket_request(
1813    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1814    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1815    Extension(server): Extension<Arc<Server>>,
1816    Extension(user): Extension<User>,
1817    ws: WebSocketUpgrade,
1818) -> axum::response::Response {
1819    if protocol_version != rpc::PROTOCOL_VERSION {
1820        return (
1821            StatusCode::UPGRADE_REQUIRED,
1822            "client must be upgraded".to_string(),
1823        )
1824            .into_response();
1825    }
1826    let socket_address = socket_address.to_string();
1827    ws.on_upgrade(move |socket| {
1828        use util::ResultExt;
1829        let socket = socket
1830            .map_ok(to_tungstenite_message)
1831            .err_into()
1832            .with(|message| async move { Ok(to_axum_message(message)) });
1833        let connection = Connection::new(Box::pin(socket));
1834        async move {
1835            server
1836                .handle_connection(connection, socket_address, user, None, RealExecutor)
1837                .await
1838                .log_err();
1839        }
1840    })
1841}
1842
1843pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1844    let metrics = server.store().await.metrics();
1845    METRIC_CONNECTIONS.set(metrics.connections as _);
1846    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1847
1848    let encoder = prometheus::TextEncoder::new();
1849    let metric_families = prometheus::gather();
1850    match encoder.encode_to_string(&metric_families) {
1851        Ok(string) => (StatusCode::OK, string).into_response(),
1852        Err(error) => (
1853            StatusCode::INTERNAL_SERVER_ERROR,
1854            format!("failed to encode metrics {:?}", error),
1855        )
1856            .into_response(),
1857    }
1858}
1859
1860fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1861    match message {
1862        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1863        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1864        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1865        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1866        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1867            code: frame.code.into(),
1868            reason: frame.reason,
1869        })),
1870    }
1871}
1872
1873fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1874    match message {
1875        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1876        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1877        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1878        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1879        AxumMessage::Close(frame) => {
1880            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1881                code: frame.code.into(),
1882                reason: frame.reason,
1883            }))
1884        }
1885    }
1886}
1887
1888pub trait ResultExt {
1889    type Ok;
1890
1891    fn trace_err(self) -> Option<Self::Ok>;
1892}
1893
1894impl<T, E> ResultExt for Result<T, E>
1895where
1896    E: std::fmt::Debug,
1897{
1898    type Ok = T;
1899
1900    fn trace_err(self) -> Option<T> {
1901        match self {
1902            Ok(value) => Some(value),
1903            Err(error) => {
1904                tracing::error!("{:?}", error);
1905                None
1906            }
1907        }
1908    }
1909}