rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52pub use store::{Store, Worktree};
  53use tokio::{
  54    sync::{Mutex, MutexGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60lazy_static! {
  61    static ref METRIC_CONNECTIONS: IntGauge =
  62        register_int_gauge!("connections", "number of connections").unwrap();
  63    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  64        "shared_projects",
  65        "number of open projects with one or more guests"
  66    )
  67    .unwrap();
  68}
  69
  70type MessageHandler = Box<
  71    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
  72>;
  73
/// An incoming message payload together with who sent it.
struct Message<T> {
    /// Id of the user whose connection sent the message.
    sender_user_id: UserId,
    /// The connection the message arrived on.
    sender_connection_id: ConnectionId,
    /// The message body.
    payload: T,
}
  79
/// Single-use handle for replying to a request; consumed by `Response::send`.
struct Response<R> {
    /// Server whose peer delivers the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set by `send`; the request dispatcher reads it to detect handlers that never replied.
    responded: Arc<AtomicBool>,
}
  85
  86impl<R: RequestMessage> Response<R> {
  87    fn send(self, payload: R::Response) -> Result<()> {
  88        self.responded.store(true, SeqCst);
  89        self.server.peer.respond(self.receipt, payload)?;
  90        Ok(())
  91    }
  92}
  93
/// The collaboration RPC server.
///
/// It owns the peer used to talk to clients, the in-memory connection store, shared
/// application state, and the per-message-type handler table.
pub struct Server {
    /// Connection/messaging layer; cloned into handler futures via `Arc<Server>`.
    peer: Arc<Peer>,
    /// In-memory state behind an async (tokio) mutex; handlers lock it via `store().await`.
    pub(crate) store: Mutex<Store>,
    /// Shared application state (database, config, optional LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
 100
 101pub trait Executor: Send + Clone {
 102    type Sleep: Send + Future;
 103    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 104    fn sleep(&self, duration: Duration) -> Self::Sleep;
 105}
 106
/// Stateless executor handle; it carries no data, so cloning it is free.
#[derive(Clone)]
pub struct RealExecutor;
 109
/// Lock guard over the server's [`Store`].
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. A guard therefore cannot be
/// moved into, or held across an `.await` inside, a future that must be `Send`.
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 114
/// Serializable view of the server: the peer plus the locked store.
///
/// The snapshot owns a `StoreGuard`, so the store lock stays held until the snapshot is
/// dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // `StoreGuard` has no `Serialize` derive, so serialize whatever it dereferences to
    // (presumably the `Store`).
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 121
 122pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 123where
 124    S: Serializer,
 125    T: Deref<Target = U>,
 126    U: Serialize,
 127{
 128    Serialize::serialize(value.deref(), serializer)
 129}
 130
 131impl Server {
 132    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 133        let mut server = Self {
 134            peer: Peer::new(),
 135            app_state,
 136            store: Default::default(),
 137            handlers: Default::default(),
 138        };
 139
 140        server
 141            .add_request_handler(Server::ping)
 142            .add_request_handler(Server::create_room)
 143            .add_request_handler(Server::join_room)
 144            .add_message_handler(Server::leave_room)
 145            .add_request_handler(Server::call)
 146            .add_request_handler(Server::cancel_call)
 147            .add_message_handler(Server::decline_call)
 148            .add_request_handler(Server::update_participant_location)
 149            .add_request_handler(Server::share_project)
 150            .add_message_handler(Server::unshare_project)
 151            .add_request_handler(Server::join_project)
 152            .add_message_handler(Server::leave_project)
 153            .add_request_handler(Server::update_project)
 154            .add_request_handler(Server::update_worktree)
 155            .add_message_handler(Server::start_language_server)
 156            .add_message_handler(Server::update_language_server)
 157            .add_request_handler(Server::update_diagnostic_summary)
 158            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
 159            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
 160            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
 161            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
 162            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
 163            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
 164            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
 165            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
 166            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
 167            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
 168            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
 169            .add_request_handler(
 170                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
 171            )
 172            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
 173            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
 174            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
 175            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
 176            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
 177            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
 178            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
 179            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
 180            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
 181            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
 182            .add_message_handler(Server::create_buffer_for_peer)
 183            .add_request_handler(Server::update_buffer)
 184            .add_message_handler(Server::update_buffer_file)
 185            .add_message_handler(Server::buffer_reloaded)
 186            .add_message_handler(Server::buffer_saved)
 187            .add_request_handler(Server::save_buffer)
 188            .add_request_handler(Server::get_users)
 189            .add_request_handler(Server::fuzzy_search_users)
 190            .add_request_handler(Server::request_contact)
 191            .add_request_handler(Server::remove_contact)
 192            .add_request_handler(Server::respond_to_contact_request)
 193            .add_request_handler(Server::follow)
 194            .add_message_handler(Server::unfollow)
 195            .add_message_handler(Server::update_followers)
 196            .add_message_handler(Server::update_diff_base)
 197            .add_request_handler(Server::get_private_user_info);
 198
 199        Arc::new(server)
 200    }
 201
 202    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 203    where
 204        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 205        Fut: 'static + Send + Future<Output = Result<()>>,
 206        M: EnvelopedMessage,
 207    {
 208        let prev_handler = self.handlers.insert(
 209            TypeId::of::<M>(),
 210            Box::new(move |server, sender_user_id, envelope| {
 211                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 212                let span = info_span!(
 213                    "handle message",
 214                    payload_type = envelope.payload_type_name()
 215                );
 216                span.in_scope(|| {
 217                    tracing::info!(
 218                        payload_type = envelope.payload_type_name(),
 219                        "message received"
 220                    );
 221                });
 222                let future = (handler)(server, sender_user_id, *envelope);
 223                async move {
 224                    if let Err(error) = future.await {
 225                        tracing::error!(%error, "error handling message");
 226                    }
 227                }
 228                .instrument(span)
 229                .boxed()
 230            }),
 231        );
 232        if prev_handler.is_some() {
 233            panic!("registered a handler for the same message twice");
 234        }
 235        self
 236    }
 237
 238    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 239    where
 240        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 241        Fut: 'static + Send + Future<Output = Result<()>>,
 242        M: EnvelopedMessage,
 243    {
 244        self.add_handler(move |server, sender_user_id, envelope| {
 245            handler(
 246                server,
 247                Message {
 248                    sender_user_id,
 249                    sender_connection_id: envelope.sender_id,
 250                    payload: envelope.payload,
 251                },
 252            )
 253        });
 254        self
 255    }
 256
 257    /// Handle a request while holding a lock to the store. This is useful when we're registering
 258    /// a connection but we want to respond on the connection before anybody else can send on it.
 259    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 260    where
 261        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 262        Fut: Send + Future<Output = Result<()>>,
 263        M: RequestMessage,
 264    {
 265        let handler = Arc::new(handler);
 266        self.add_handler(move |server, sender_user_id, envelope| {
 267            let receipt = envelope.receipt();
 268            let handler = handler.clone();
 269            async move {
 270                let request = Message {
 271                    sender_user_id,
 272                    sender_connection_id: envelope.sender_id,
 273                    payload: envelope.payload,
 274                };
 275                let responded = Arc::new(AtomicBool::default());
 276                let response = Response {
 277                    server: server.clone(),
 278                    responded: responded.clone(),
 279                    receipt,
 280                };
 281                match (handler)(server.clone(), request, response).await {
 282                    Ok(()) => {
 283                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 284                            Ok(())
 285                        } else {
 286                            Err(anyhow!("handler did not send a response"))?
 287                        }
 288                    }
 289                    Err(error) => {
 290                        server.peer.respond_with_error(
 291                            receipt,
 292                            proto::Error {
 293                                message: error.to_string(),
 294                            },
 295                        )?;
 296                        Err(error)
 297                    }
 298                }
 299            }
 300        })
 301    }
 302
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Steps:
    /// 1. Registers `connection` with the peer and sends `Hello`.
    /// 2. Reports the new connection id through `send_connection_id`, if provided.
    /// 3. Sends the initial contacts, invite info, and any pending incoming call.
    /// 4. Dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes or the incoming stream ends.
    /// 5. Calls `sign_out`.
    ///
    /// Background messages are spawned on `executor`. All other handler futures are polled
    /// concurrently on this task (see the comment above `foreground_message_handlers`).
    ///
    /// NOTE(review): every `?` before the message loop returns early without calling
    /// `sign_out`. A failure after `peer.add_connection` / `store.add_connection` therefore
    /// appears to leave the connection registered. Confirm whether cleanup happens elsewhere.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The second argument is a factory the peer uses to create timers of a given
            // duration, backed by `executor.sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Best-effort notification: a dropped receiver is ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // Only for users whose `connected_once` flag is not yet set; the flag is then
            // persisted.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch contacts and invite code concurrently; either failure aborts the connection.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Register the connection and send its initial state while holding the store lock.
            {
                let mut store = this.store().await;
                store.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks branches in the order written: I/O completion first,
                // then progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                // Exit the span before re-attaching it to the future via
                                // `instrument`.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before tearing the connection down.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id, user_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 424
 425    #[instrument(skip(self), err)]
 426    async fn sign_out(
 427        self: &mut Arc<Self>,
 428        connection_id: ConnectionId,
 429        user_id: UserId,
 430    ) -> Result<()> {
 431        self.peer.disconnect(connection_id);
 432        let decline_calls = {
 433            let mut store = self.store().await;
 434            store.remove_connection(connection_id)?;
 435            let mut connections = store.connection_ids_for_user(user_id);
 436            connections.next().is_none()
 437        };
 438
 439        self.leave_room_for_connection(connection_id, user_id)
 440            .await
 441            .trace_err();
 442        if decline_calls {
 443            if let Some(room) = self
 444                .app_state
 445                .db
 446                .decline_call(None, user_id)
 447                .await
 448                .trace_err()
 449            {
 450                self.room_updated(&room);
 451            }
 452        }
 453
 454        self.update_user_contacts(user_id).await?;
 455
 456        Ok(())
 457    }
 458
 459    pub async fn invite_code_redeemed(
 460        self: &Arc<Self>,
 461        inviter_id: UserId,
 462        invitee_id: UserId,
 463    ) -> Result<()> {
 464        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 465            if let Some(code) = &user.invite_code {
 466                let store = self.store().await;
 467                let invitee_contact = store.contact_for_user(invitee_id, true, false);
 468                for connection_id in store.connection_ids_for_user(inviter_id) {
 469                    self.peer.send(
 470                        connection_id,
 471                        proto::UpdateContacts {
 472                            contacts: vec![invitee_contact.clone()],
 473                            ..Default::default()
 474                        },
 475                    )?;
 476                    self.peer.send(
 477                        connection_id,
 478                        proto::UpdateInviteInfo {
 479                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 480                            count: user.invite_count as u32,
 481                        },
 482                    )?;
 483                }
 484            }
 485        }
 486        Ok(())
 487    }
 488
 489    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 490        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 491            if let Some(invite_code) = &user.invite_code {
 492                let store = self.store().await;
 493                for connection_id in store.connection_ids_for_user(user_id) {
 494                    self.peer.send(
 495                        connection_id,
 496                        proto::UpdateInviteInfo {
 497                            url: format!(
 498                                "{}{}",
 499                                self.app_state.config.invite_link_prefix, invite_code
 500                            ),
 501                            count: user.invite_count as u32,
 502                        },
 503                    )?;
 504                }
 505            }
 506        }
 507        Ok(())
 508    }
 509
 510    async fn ping(
 511        self: Arc<Server>,
 512        _: Message<proto::Ping>,
 513        response: Response<proto::Ping>,
 514    ) -> Result<()> {
 515        response.send(proto::Ack {})?;
 516        Ok(())
 517    }
 518
 519    async fn create_room(
 520        self: Arc<Server>,
 521        request: Message<proto::CreateRoom>,
 522        response: Response<proto::CreateRoom>,
 523    ) -> Result<()> {
 524        let room = self
 525            .app_state
 526            .db
 527            .create_room(request.sender_user_id, request.sender_connection_id)
 528            .await?;
 529
 530        let live_kit_connection_info =
 531            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 532                if let Some(_) = live_kit
 533                    .create_room(room.live_kit_room.clone())
 534                    .await
 535                    .trace_err()
 536                {
 537                    if let Some(token) = live_kit
 538                        .room_token(
 539                            &room.live_kit_room,
 540                            &request.sender_connection_id.to_string(),
 541                        )
 542                        .trace_err()
 543                    {
 544                        Some(proto::LiveKitConnectionInfo {
 545                            server_url: live_kit.url().into(),
 546                            token,
 547                        })
 548                    } else {
 549                        None
 550                    }
 551                } else {
 552                    None
 553                }
 554            } else {
 555                None
 556            };
 557
 558        response.send(proto::CreateRoomResponse {
 559            room: Some(room),
 560            live_kit_connection_info,
 561        })?;
 562        self.update_user_contacts(request.sender_user_id).await?;
 563        Ok(())
 564    }
 565
 566    async fn join_room(
 567        self: Arc<Server>,
 568        request: Message<proto::JoinRoom>,
 569        response: Response<proto::JoinRoom>,
 570    ) -> Result<()> {
 571        let room = self
 572            .app_state
 573            .db
 574            .join_room(
 575                RoomId::from_proto(request.payload.id),
 576                request.sender_user_id,
 577                request.sender_connection_id,
 578            )
 579            .await?;
 580        for connection_id in self
 581            .store()
 582            .await
 583            .connection_ids_for_user(request.sender_user_id)
 584        {
 585            self.peer
 586                .send(connection_id, proto::CallCanceled {})
 587                .trace_err();
 588        }
 589
 590        let live_kit_connection_info =
 591            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 592                if let Some(token) = live_kit
 593                    .room_token(
 594                        &room.live_kit_room,
 595                        &request.sender_connection_id.to_string(),
 596                    )
 597                    .trace_err()
 598                {
 599                    Some(proto::LiveKitConnectionInfo {
 600                        server_url: live_kit.url().into(),
 601                        token,
 602                    })
 603                } else {
 604                    None
 605                }
 606            } else {
 607                None
 608            };
 609
 610        self.room_updated(&room);
 611        response.send(proto::JoinRoomResponse {
 612            room: Some(room),
 613            live_kit_connection_info,
 614        })?;
 615
 616        self.update_user_contacts(request.sender_user_id).await?;
 617        Ok(())
 618    }
 619
 620    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 621        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
 622            .await
 623    }
 624
 625    async fn leave_room_for_connection(
 626        self: &Arc<Server>,
 627        leaving_connection_id: ConnectionId,
 628        leaving_user_id: UserId,
 629    ) -> Result<()> {
 630        let mut contacts_to_update = HashSet::default();
 631
 632        let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else {
 633            return Err(anyhow!("no room to leave"))?;
 634        };
 635        contacts_to_update.insert(leaving_user_id);
 636
 637        for project in left_room.left_projects.into_values() {
 638            for connection_id in project.connection_ids {
 639                if project.host_user_id == leaving_user_id {
 640                    self.peer
 641                        .send(
 642                            connection_id,
 643                            proto::UnshareProject {
 644                                project_id: project.id.to_proto(),
 645                            },
 646                        )
 647                        .trace_err();
 648                } else {
 649                    self.peer
 650                        .send(
 651                            connection_id,
 652                            proto::RemoveProjectCollaborator {
 653                                project_id: project.id.to_proto(),
 654                                peer_id: leaving_connection_id.0,
 655                            },
 656                        )
 657                        .trace_err();
 658                }
 659            }
 660
 661            self.peer
 662                .send(
 663                    leaving_connection_id,
 664                    proto::UnshareProject {
 665                        project_id: project.id.to_proto(),
 666                    },
 667                )
 668                .trace_err();
 669        }
 670
 671        self.room_updated(&left_room.room);
 672        {
 673            let store = self.store().await;
 674            for canceled_user_id in left_room.canceled_calls_to_user_ids {
 675                for connection_id in store.connection_ids_for_user(canceled_user_id) {
 676                    self.peer
 677                        .send(connection_id, proto::CallCanceled {})
 678                        .trace_err();
 679                }
 680                contacts_to_update.insert(canceled_user_id);
 681            }
 682        }
 683
 684        for contact_user_id in contacts_to_update {
 685            self.update_user_contacts(contact_user_id).await?;
 686        }
 687
 688        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 689            live_kit
 690                .remove_participant(
 691                    left_room.room.live_kit_room.clone(),
 692                    leaving_connection_id.to_string(),
 693                )
 694                .await
 695                .trace_err();
 696
 697            if left_room.room.participants.is_empty() {
 698                live_kit
 699                    .delete_room(left_room.room.live_kit_room)
 700                    .await
 701                    .trace_err();
 702            }
 703        }
 704
 705        Ok(())
 706    }
 707
 708    async fn call(
 709        self: Arc<Server>,
 710        request: Message<proto::Call>,
 711        response: Response<proto::Call>,
 712    ) -> Result<()> {
 713        let room_id = RoomId::from_proto(request.payload.room_id);
 714        let calling_user_id = request.sender_user_id;
 715        let calling_connection_id = request.sender_connection_id;
 716        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 717        let initial_project_id = request
 718            .payload
 719            .initial_project_id
 720            .map(ProjectId::from_proto);
 721        if !self
 722            .app_state
 723            .db
 724            .has_contact(calling_user_id, called_user_id)
 725            .await?
 726        {
 727            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 728        }
 729
 730        let (room, incoming_call) = self
 731            .app_state
 732            .db
 733            .call(
 734                room_id,
 735                calling_user_id,
 736                calling_connection_id,
 737                called_user_id,
 738                initial_project_id,
 739            )
 740            .await?;
 741        self.room_updated(&room);
 742        self.update_user_contacts(called_user_id).await?;
 743
 744        let mut calls = self
 745            .store()
 746            .await
 747            .connection_ids_for_user(called_user_id)
 748            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 749            .collect::<FuturesUnordered<_>>();
 750
 751        while let Some(call_response) = calls.next().await {
 752            match call_response.as_ref() {
 753                Ok(_) => {
 754                    response.send(proto::Ack {})?;
 755                    return Ok(());
 756                }
 757                Err(_) => {
 758                    call_response.trace_err();
 759                }
 760            }
 761        }
 762
 763        let room = self
 764            .app_state
 765            .db
 766            .call_failed(room_id, called_user_id)
 767            .await?;
 768        self.room_updated(&room);
 769        self.update_user_contacts(called_user_id).await?;
 770
 771        Err(anyhow!("failed to ring user"))?
 772    }
 773
 774    async fn cancel_call(
 775        self: Arc<Server>,
 776        request: Message<proto::CancelCall>,
 777        response: Response<proto::CancelCall>,
 778    ) -> Result<()> {
 779        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 780        let room_id = RoomId::from_proto(request.payload.room_id);
 781        let room = self
 782            .app_state
 783            .db
 784            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
 785            .await?;
 786        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
 787            self.peer
 788                .send(connection_id, proto::CallCanceled {})
 789                .trace_err();
 790        }
 791        self.room_updated(&room);
 792        response.send(proto::Ack {})?;
 793
 794        self.update_user_contacts(called_user_id).await?;
 795        Ok(())
 796    }
 797
 798    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 799        let room_id = RoomId::from_proto(message.payload.room_id);
 800        let room = self
 801            .app_state
 802            .db
 803            .decline_call(Some(room_id), message.sender_user_id)
 804            .await?;
 805        for connection_id in self
 806            .store()
 807            .await
 808            .connection_ids_for_user(message.sender_user_id)
 809        {
 810            self.peer
 811                .send(connection_id, proto::CallCanceled {})
 812                .trace_err();
 813        }
 814        self.room_updated(&room);
 815        self.update_user_contacts(message.sender_user_id).await?;
 816        Ok(())
 817    }
 818
 819    async fn update_participant_location(
 820        self: Arc<Server>,
 821        request: Message<proto::UpdateParticipantLocation>,
 822        response: Response<proto::UpdateParticipantLocation>,
 823    ) -> Result<()> {
 824        let room_id = RoomId::from_proto(request.payload.room_id);
 825        let location = request
 826            .payload
 827            .location
 828            .ok_or_else(|| anyhow!("invalid location"))?;
 829        let room = self
 830            .app_state
 831            .db
 832            .update_room_participant_location(room_id, request.sender_connection_id, location)
 833            .await?;
 834        self.room_updated(&room);
 835        response.send(proto::Ack {})?;
 836        Ok(())
 837    }
 838
 839    fn room_updated(&self, room: &proto::Room) {
 840        for participant in &room.participants {
 841            self.peer
 842                .send(
 843                    ConnectionId(participant.peer_id),
 844                    proto::RoomUpdated {
 845                        room: Some(room.clone()),
 846                    },
 847                )
 848                .trace_err();
 849        }
 850    }
 851
 852    async fn share_project(
 853        self: Arc<Server>,
 854        request: Message<proto::ShareProject>,
 855        response: Response<proto::ShareProject>,
 856    ) -> Result<()> {
 857        let (project_id, room) = self
 858            .app_state
 859            .db
 860            .share_project(
 861                RoomId::from_proto(request.payload.room_id),
 862                request.sender_connection_id,
 863                &request.payload.worktrees,
 864            )
 865            .await
 866            .unwrap();
 867        response.send(proto::ShareProjectResponse {
 868            project_id: project_id.to_proto(),
 869        })?;
 870        self.room_updated(&room);
 871
 872        Ok(())
 873    }
 874
 875    async fn unshare_project(
 876        self: Arc<Server>,
 877        message: Message<proto::UnshareProject>,
 878    ) -> Result<()> {
 879        let project_id = ProjectId::from_proto(message.payload.project_id);
 880        let mut store = self.store().await;
 881        let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?;
 882        broadcast(
 883            message.sender_connection_id,
 884            project.guest_connection_ids(),
 885            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 886        );
 887        self.room_updated(room);
 888
 889        Ok(())
 890    }
 891
 892    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 893        let contacts = self.app_state.db.get_contacts(user_id).await?;
 894        let busy = self.app_state.db.is_user_busy(user_id).await?;
 895        let store = self.store().await;
 896        let updated_contact = store.contact_for_user(user_id, false, busy);
 897        for contact in contacts {
 898            if let db::Contact::Accepted {
 899                user_id: contact_user_id,
 900                ..
 901            } = contact
 902            {
 903                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 904                    self.peer
 905                        .send(
 906                            contact_conn_id,
 907                            proto::UpdateContacts {
 908                                contacts: vec![updated_contact.clone()],
 909                                remove_contacts: Default::default(),
 910                                incoming_requests: Default::default(),
 911                                remove_incoming_requests: Default::default(),
 912                                outgoing_requests: Default::default(),
 913                                remove_outgoing_requests: Default::default(),
 914                            },
 915                        )
 916                        .trace_err();
 917                }
 918            }
 919        }
 920        Ok(())
 921    }
 922
 923    async fn join_project(
 924        self: Arc<Server>,
 925        request: Message<proto::JoinProject>,
 926        response: Response<proto::JoinProject>,
 927    ) -> Result<()> {
 928        let project_id = ProjectId::from_proto(request.payload.project_id);
 929        let guest_user_id = request.sender_user_id;
 930
 931        tracing::info!(%project_id, "join project");
 932
 933        let (project, replica_id) = self
 934            .app_state
 935            .db
 936            .join_project(project_id, request.sender_connection_id)
 937            .await?;
 938
 939        let collaborators = project
 940            .collaborators
 941            .iter()
 942            .filter(|collaborator| {
 943                collaborator.connection_id != request.sender_connection_id.0 as i32
 944            })
 945            .map(|collaborator| proto::Collaborator {
 946                peer_id: collaborator.connection_id as u32,
 947                replica_id: collaborator.replica_id.0 as u32,
 948                user_id: collaborator.user_id.to_proto(),
 949            })
 950            .collect::<Vec<_>>();
 951        let worktrees = project
 952            .worktrees
 953            .iter()
 954            .map(|(id, worktree)| proto::WorktreeMetadata {
 955                id: id.to_proto(),
 956                root_name: worktree.root_name.clone(),
 957                visible: worktree.visible,
 958                abs_path: worktree.abs_path.clone(),
 959            })
 960            .collect::<Vec<_>>();
 961
 962        for collaborator in &collaborators {
 963            self.peer
 964                .send(
 965                    ConnectionId(collaborator.peer_id),
 966                    proto::AddProjectCollaborator {
 967                        project_id: project_id.to_proto(),
 968                        collaborator: Some(proto::Collaborator {
 969                            peer_id: request.sender_connection_id.0,
 970                            replica_id: replica_id.0 as u32,
 971                            user_id: guest_user_id.to_proto(),
 972                        }),
 973                    },
 974                )
 975                .trace_err();
 976        }
 977
 978        // First, we send the metadata associated with each worktree.
 979        response.send(proto::JoinProjectResponse {
 980            worktrees: worktrees.clone(),
 981            replica_id: replica_id.0 as u32,
 982            collaborators: collaborators.clone(),
 983            language_servers: project.language_servers.clone(),
 984        })?;
 985
 986        for (worktree_id, worktree) in project.worktrees {
 987            #[cfg(any(test, feature = "test-support"))]
 988            const MAX_CHUNK_SIZE: usize = 2;
 989            #[cfg(not(any(test, feature = "test-support")))]
 990            const MAX_CHUNK_SIZE: usize = 256;
 991
 992            // Stream this worktree's entries.
 993            let message = proto::UpdateWorktree {
 994                project_id: project_id.to_proto(),
 995                worktree_id: worktree_id.to_proto(),
 996                abs_path: worktree.abs_path.clone(),
 997                root_name: worktree.root_name,
 998                updated_entries: worktree.entries,
 999                removed_entries: Default::default(),
1000                scan_id: worktree.scan_id,
1001                is_last_update: worktree.is_complete,
1002            };
1003            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1004                self.peer
1005                    .send(request.sender_connection_id, update.clone())?;
1006            }
1007
1008            // Stream this worktree's diagnostics.
1009            for summary in worktree.diagnostic_summaries {
1010                self.peer.send(
1011                    request.sender_connection_id,
1012                    proto::UpdateDiagnosticSummary {
1013                        project_id: project_id.to_proto(),
1014                        worktree_id: worktree.id.to_proto(),
1015                        summary: Some(summary),
1016                    },
1017                )?;
1018            }
1019        }
1020
1021        for language_server in &project.language_servers {
1022            self.peer.send(
1023                request.sender_connection_id,
1024                proto::UpdateLanguageServer {
1025                    project_id: project_id.to_proto(),
1026                    language_server_id: language_server.id,
1027                    variant: Some(
1028                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1029                            proto::LspDiskBasedDiagnosticsUpdated {},
1030                        ),
1031                    ),
1032                },
1033            )?;
1034        }
1035
1036        Ok(())
1037    }
1038
1039    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1040        let sender_id = request.sender_connection_id;
1041        let project_id = ProjectId::from_proto(request.payload.project_id);
1042        let project;
1043        {
1044            let mut store = self.store().await;
1045            project = store.leave_project(project_id, sender_id)?;
1046            tracing::info!(
1047                %project_id,
1048                host_user_id = %project.host_user_id,
1049                host_connection_id = %project.host_connection_id,
1050                "leave project"
1051            );
1052
1053            if project.remove_collaborator {
1054                broadcast(sender_id, project.connection_ids, |conn_id| {
1055                    self.peer.send(
1056                        conn_id,
1057                        proto::RemoveProjectCollaborator {
1058                            project_id: project_id.to_proto(),
1059                            peer_id: sender_id.0,
1060                        },
1061                    )
1062                });
1063            }
1064        }
1065
1066        Ok(())
1067    }
1068
1069    async fn update_project(
1070        self: Arc<Server>,
1071        request: Message<proto::UpdateProject>,
1072        response: Response<proto::UpdateProject>,
1073    ) -> Result<()> {
1074        let project_id = ProjectId::from_proto(request.payload.project_id);
1075        let (room, guest_connection_ids) = self
1076            .app_state
1077            .db
1078            .update_project(
1079                project_id,
1080                request.sender_connection_id,
1081                &request.payload.worktrees,
1082            )
1083            .await?;
1084        broadcast(
1085            request.sender_connection_id,
1086            guest_connection_ids,
1087            |connection_id| {
1088                self.peer.forward_send(
1089                    request.sender_connection_id,
1090                    connection_id,
1091                    request.payload.clone(),
1092                )
1093            },
1094        );
1095        self.room_updated(&room);
1096        response.send(proto::Ack {})?;
1097
1098        Ok(())
1099    }
1100
1101    async fn update_worktree(
1102        self: Arc<Server>,
1103        request: Message<proto::UpdateWorktree>,
1104        response: Response<proto::UpdateWorktree>,
1105    ) -> Result<()> {
1106        let guest_connection_ids = self
1107            .app_state
1108            .db
1109            .update_worktree(&request.payload, request.sender_connection_id)
1110            .await?;
1111
1112        broadcast(
1113            request.sender_connection_id,
1114            guest_connection_ids,
1115            |connection_id| {
1116                self.peer.forward_send(
1117                    request.sender_connection_id,
1118                    connection_id,
1119                    request.payload.clone(),
1120                )
1121            },
1122        );
1123        response.send(proto::Ack {})?;
1124        Ok(())
1125    }
1126
1127    async fn update_diagnostic_summary(
1128        self: Arc<Server>,
1129        request: Message<proto::UpdateDiagnosticSummary>,
1130        response: Response<proto::UpdateDiagnosticSummary>,
1131    ) -> Result<()> {
1132        let guest_connection_ids = self
1133            .app_state
1134            .db
1135            .update_diagnostic_summary(&request.payload, request.sender_connection_id)
1136            .await?;
1137
1138        broadcast(
1139            request.sender_connection_id,
1140            guest_connection_ids,
1141            |connection_id| {
1142                self.peer.forward_send(
1143                    request.sender_connection_id,
1144                    connection_id,
1145                    request.payload.clone(),
1146                )
1147            },
1148        );
1149
1150        response.send(proto::Ack {})?;
1151        Ok(())
1152    }
1153
1154    async fn start_language_server(
1155        self: Arc<Server>,
1156        request: Message<proto::StartLanguageServer>,
1157    ) -> Result<()> {
1158        let guest_connection_ids = self
1159            .app_state
1160            .db
1161            .start_language_server(&request.payload, request.sender_connection_id)
1162            .await?;
1163
1164        broadcast(
1165            request.sender_connection_id,
1166            guest_connection_ids,
1167            |connection_id| {
1168                self.peer.forward_send(
1169                    request.sender_connection_id,
1170                    connection_id,
1171                    request.payload.clone(),
1172                )
1173            },
1174        );
1175        Ok(())
1176    }
1177
1178    async fn update_language_server(
1179        self: Arc<Server>,
1180        request: Message<proto::UpdateLanguageServer>,
1181    ) -> Result<()> {
1182        let project_id = ProjectId::from_proto(request.payload.project_id);
1183        let project_connection_ids = self
1184            .app_state
1185            .db
1186            .project_connection_ids(project_id, request.sender_connection_id)
1187            .await?;
1188        broadcast(
1189            request.sender_connection_id,
1190            project_connection_ids,
1191            |connection_id| {
1192                self.peer.forward_send(
1193                    request.sender_connection_id,
1194                    connection_id,
1195                    request.payload.clone(),
1196                )
1197            },
1198        );
1199        Ok(())
1200    }
1201
1202    async fn forward_project_request<T>(
1203        self: Arc<Server>,
1204        request: Message<T>,
1205        response: Response<T>,
1206    ) -> Result<()>
1207    where
1208        T: EntityMessage + RequestMessage,
1209    {
1210        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1211        let collaborators = self
1212            .app_state
1213            .db
1214            .project_collaborators(project_id, request.sender_connection_id)
1215            .await?;
1216        let host = collaborators
1217            .iter()
1218            .find(|collaborator| collaborator.is_host)
1219            .ok_or_else(|| anyhow!("host not found"))?;
1220
1221        let payload = self
1222            .peer
1223            .forward_request(
1224                request.sender_connection_id,
1225                ConnectionId(host.connection_id as u32),
1226                request.payload,
1227            )
1228            .await?;
1229
1230        response.send(payload)?;
1231        Ok(())
1232    }
1233
1234    async fn save_buffer(
1235        self: Arc<Server>,
1236        request: Message<proto::SaveBuffer>,
1237        response: Response<proto::SaveBuffer>,
1238    ) -> Result<()> {
1239        let project_id = ProjectId::from_proto(request.payload.project_id);
1240        let collaborators = self
1241            .app_state
1242            .db
1243            .project_collaborators(project_id, request.sender_connection_id)
1244            .await?;
1245        let host = collaborators
1246            .into_iter()
1247            .find(|collaborator| collaborator.is_host)
1248            .ok_or_else(|| anyhow!("host not found"))?;
1249        let host_connection_id = ConnectionId(host.connection_id as u32);
1250        let response_payload = self
1251            .peer
1252            .forward_request(
1253                request.sender_connection_id,
1254                host_connection_id,
1255                request.payload.clone(),
1256            )
1257            .await?;
1258
1259        let mut collaborators = self
1260            .app_state
1261            .db
1262            .project_collaborators(project_id, request.sender_connection_id)
1263            .await?;
1264        collaborators.retain(|collaborator| {
1265            collaborator.connection_id != request.sender_connection_id.0 as i32
1266        });
1267        let project_connection_ids = collaborators
1268            .into_iter()
1269            .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1270        broadcast(host_connection_id, project_connection_ids, |conn_id| {
1271            self.peer
1272                .forward_send(host_connection_id, conn_id, response_payload.clone())
1273        });
1274        response.send(response_payload)?;
1275        Ok(())
1276    }
1277
1278    async fn create_buffer_for_peer(
1279        self: Arc<Server>,
1280        request: Message<proto::CreateBufferForPeer>,
1281    ) -> Result<()> {
1282        self.peer.forward_send(
1283            request.sender_connection_id,
1284            ConnectionId(request.payload.peer_id),
1285            request.payload,
1286        )?;
1287        Ok(())
1288    }
1289
1290    async fn update_buffer(
1291        self: Arc<Server>,
1292        request: Message<proto::UpdateBuffer>,
1293        response: Response<proto::UpdateBuffer>,
1294    ) -> Result<()> {
1295        let project_id = ProjectId::from_proto(request.payload.project_id);
1296        let project_connection_ids = self
1297            .app_state
1298            .db
1299            .project_connection_ids(project_id, request.sender_connection_id)
1300            .await?;
1301
1302        broadcast(
1303            request.sender_connection_id,
1304            project_connection_ids,
1305            |connection_id| {
1306                self.peer.forward_send(
1307                    request.sender_connection_id,
1308                    connection_id,
1309                    request.payload.clone(),
1310                )
1311            },
1312        );
1313        response.send(proto::Ack {})?;
1314        Ok(())
1315    }
1316
1317    async fn update_buffer_file(
1318        self: Arc<Server>,
1319        request: Message<proto::UpdateBufferFile>,
1320    ) -> Result<()> {
1321        let project_id = ProjectId::from_proto(request.payload.project_id);
1322        let project_connection_ids = self
1323            .app_state
1324            .db
1325            .project_connection_ids(project_id, request.sender_connection_id)
1326            .await?;
1327
1328        broadcast(
1329            request.sender_connection_id,
1330            project_connection_ids,
1331            |connection_id| {
1332                self.peer.forward_send(
1333                    request.sender_connection_id,
1334                    connection_id,
1335                    request.payload.clone(),
1336                )
1337            },
1338        );
1339        Ok(())
1340    }
1341
1342    async fn buffer_reloaded(
1343        self: Arc<Server>,
1344        request: Message<proto::BufferReloaded>,
1345    ) -> Result<()> {
1346        let project_id = ProjectId::from_proto(request.payload.project_id);
1347        let project_connection_ids = self
1348            .app_state
1349            .db
1350            .project_connection_ids(project_id, request.sender_connection_id)
1351            .await?;
1352        broadcast(
1353            request.sender_connection_id,
1354            project_connection_ids,
1355            |connection_id| {
1356                self.peer.forward_send(
1357                    request.sender_connection_id,
1358                    connection_id,
1359                    request.payload.clone(),
1360                )
1361            },
1362        );
1363        Ok(())
1364    }
1365
1366    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1367        let project_id = ProjectId::from_proto(request.payload.project_id);
1368        let project_connection_ids = self
1369            .app_state
1370            .db
1371            .project_connection_ids(project_id, request.sender_connection_id)
1372            .await?;
1373        broadcast(
1374            request.sender_connection_id,
1375            project_connection_ids,
1376            |connection_id| {
1377                self.peer.forward_send(
1378                    request.sender_connection_id,
1379                    connection_id,
1380                    request.payload.clone(),
1381                )
1382            },
1383        );
1384        Ok(())
1385    }
1386
1387    async fn follow(
1388        self: Arc<Self>,
1389        request: Message<proto::Follow>,
1390        response: Response<proto::Follow>,
1391    ) -> Result<()> {
1392        let project_id = ProjectId::from_proto(request.payload.project_id);
1393        let leader_id = ConnectionId(request.payload.leader_id);
1394        let follower_id = request.sender_connection_id;
1395        let project_connection_ids = self
1396            .app_state
1397            .db
1398            .project_connection_ids(project_id, request.sender_connection_id)
1399            .await?;
1400
1401        if !project_connection_ids.contains(&leader_id) {
1402            Err(anyhow!("no such peer"))?;
1403        }
1404
1405        let mut response_payload = self
1406            .peer
1407            .forward_request(request.sender_connection_id, leader_id, request.payload)
1408            .await?;
1409        response_payload
1410            .views
1411            .retain(|view| view.leader_id != Some(follower_id.0));
1412        response.send(response_payload)?;
1413        Ok(())
1414    }
1415
1416    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1417        let project_id = ProjectId::from_proto(request.payload.project_id);
1418        let leader_id = ConnectionId(request.payload.leader_id);
1419        let project_connection_ids = self
1420            .app_state
1421            .db
1422            .project_connection_ids(project_id, request.sender_connection_id)
1423            .await?;
1424        if !project_connection_ids.contains(&leader_id) {
1425            Err(anyhow!("no such peer"))?;
1426        }
1427        self.peer
1428            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1429        Ok(())
1430    }
1431
1432    async fn update_followers(
1433        self: Arc<Self>,
1434        request: Message<proto::UpdateFollowers>,
1435    ) -> Result<()> {
1436        let project_id = ProjectId::from_proto(request.payload.project_id);
1437        let project_connection_ids = self
1438            .app_state
1439            .db
1440            .project_connection_ids(project_id, request.sender_connection_id)
1441            .await?;
1442
1443        let leader_id = request
1444            .payload
1445            .variant
1446            .as_ref()
1447            .and_then(|variant| match variant {
1448                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1449                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1450                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1451            });
1452        for follower_id in &request.payload.follower_ids {
1453            let follower_id = ConnectionId(*follower_id);
1454            if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1455                self.peer.forward_send(
1456                    request.sender_connection_id,
1457                    follower_id,
1458                    request.payload.clone(),
1459                )?;
1460            }
1461        }
1462        Ok(())
1463    }
1464
1465    async fn get_users(
1466        self: Arc<Server>,
1467        request: Message<proto::GetUsers>,
1468        response: Response<proto::GetUsers>,
1469    ) -> Result<()> {
1470        let user_ids = request
1471            .payload
1472            .user_ids
1473            .into_iter()
1474            .map(UserId::from_proto)
1475            .collect();
1476        let users = self
1477            .app_state
1478            .db
1479            .get_users_by_ids(user_ids)
1480            .await?
1481            .into_iter()
1482            .map(|user| proto::User {
1483                id: user.id.to_proto(),
1484                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1485                github_login: user.github_login,
1486            })
1487            .collect();
1488        response.send(proto::UsersResponse { users })?;
1489        Ok(())
1490    }
1491
1492    async fn fuzzy_search_users(
1493        self: Arc<Server>,
1494        request: Message<proto::FuzzySearchUsers>,
1495        response: Response<proto::FuzzySearchUsers>,
1496    ) -> Result<()> {
1497        let query = request.payload.query;
1498        let db = &self.app_state.db;
1499        let users = match query.len() {
1500            0 => vec![],
1501            1 | 2 => db
1502                .get_user_by_github_account(&query, None)
1503                .await?
1504                .into_iter()
1505                .collect(),
1506            _ => db.fuzzy_search_users(&query, 10).await?,
1507        };
1508        let users = users
1509            .into_iter()
1510            .filter(|user| user.id != request.sender_user_id)
1511            .map(|user| proto::User {
1512                id: user.id.to_proto(),
1513                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1514                github_login: user.github_login,
1515            })
1516            .collect();
1517        response.send(proto::UsersResponse { users })?;
1518        Ok(())
1519    }
1520
1521    async fn request_contact(
1522        self: Arc<Server>,
1523        request: Message<proto::RequestContact>,
1524        response: Response<proto::RequestContact>,
1525    ) -> Result<()> {
1526        let requester_id = request.sender_user_id;
1527        let responder_id = UserId::from_proto(request.payload.responder_id);
1528        if requester_id == responder_id {
1529            return Err(anyhow!("cannot add yourself as a contact"))?;
1530        }
1531
1532        self.app_state
1533            .db
1534            .send_contact_request(requester_id, responder_id)
1535            .await?;
1536
1537        // Update outgoing contact requests of requester
1538        let mut update = proto::UpdateContacts::default();
1539        update.outgoing_requests.push(responder_id.to_proto());
1540        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1541            self.peer.send(connection_id, update.clone())?;
1542        }
1543
1544        // Update incoming contact requests of responder
1545        let mut update = proto::UpdateContacts::default();
1546        update
1547            .incoming_requests
1548            .push(proto::IncomingContactRequest {
1549                requester_id: requester_id.to_proto(),
1550                should_notify: true,
1551            });
1552        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1553            self.peer.send(connection_id, update.clone())?;
1554        }
1555
1556        response.send(proto::Ack {})?;
1557        Ok(())
1558    }
1559
1560    async fn respond_to_contact_request(
1561        self: Arc<Server>,
1562        request: Message<proto::RespondToContactRequest>,
1563        response: Response<proto::RespondToContactRequest>,
1564    ) -> Result<()> {
1565        let responder_id = request.sender_user_id;
1566        let requester_id = UserId::from_proto(request.payload.requester_id);
1567        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1568            self.app_state
1569                .db
1570                .dismiss_contact_notification(responder_id, requester_id)
1571                .await?;
1572        } else {
1573            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1574            self.app_state
1575                .db
1576                .respond_to_contact_request(responder_id, requester_id, accept)
1577                .await?;
1578            let busy = self.app_state.db.is_user_busy(requester_id).await?;
1579
1580            let store = self.store().await;
1581            // Update responder with new contact
1582            let mut update = proto::UpdateContacts::default();
1583            if accept {
1584                update
1585                    .contacts
1586                    .push(store.contact_for_user(requester_id, false, busy));
1587            }
1588            update
1589                .remove_incoming_requests
1590                .push(requester_id.to_proto());
1591            for connection_id in store.connection_ids_for_user(responder_id) {
1592                self.peer.send(connection_id, update.clone())?;
1593            }
1594
1595            // Update requester with new contact
1596            let mut update = proto::UpdateContacts::default();
1597            if accept {
1598                update
1599                    .contacts
1600                    .push(store.contact_for_user(responder_id, true, busy));
1601            }
1602            update
1603                .remove_outgoing_requests
1604                .push(responder_id.to_proto());
1605            for connection_id in store.connection_ids_for_user(requester_id) {
1606                self.peer.send(connection_id, update.clone())?;
1607            }
1608        }
1609
1610        response.send(proto::Ack {})?;
1611        Ok(())
1612    }
1613
1614    async fn remove_contact(
1615        self: Arc<Server>,
1616        request: Message<proto::RemoveContact>,
1617        response: Response<proto::RemoveContact>,
1618    ) -> Result<()> {
1619        let requester_id = request.sender_user_id;
1620        let responder_id = UserId::from_proto(request.payload.user_id);
1621        self.app_state
1622            .db
1623            .remove_contact(requester_id, responder_id)
1624            .await?;
1625
1626        // Update outgoing contact requests of requester
1627        let mut update = proto::UpdateContacts::default();
1628        update
1629            .remove_outgoing_requests
1630            .push(responder_id.to_proto());
1631        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1632            self.peer.send(connection_id, update.clone())?;
1633        }
1634
1635        // Update incoming contact requests of responder
1636        let mut update = proto::UpdateContacts::default();
1637        update
1638            .remove_incoming_requests
1639            .push(requester_id.to_proto());
1640        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1641            self.peer.send(connection_id, update.clone())?;
1642        }
1643
1644        response.send(proto::Ack {})?;
1645        Ok(())
1646    }
1647
1648    async fn update_diff_base(
1649        self: Arc<Server>,
1650        request: Message<proto::UpdateDiffBase>,
1651    ) -> Result<()> {
1652        let project_id = ProjectId::from_proto(request.payload.project_id);
1653        let project_connection_ids = self
1654            .app_state
1655            .db
1656            .project_connection_ids(project_id, request.sender_connection_id)
1657            .await?;
1658        broadcast(
1659            request.sender_connection_id,
1660            project_connection_ids,
1661            |connection_id| {
1662                self.peer.forward_send(
1663                    request.sender_connection_id,
1664                    connection_id,
1665                    request.payload.clone(),
1666                )
1667            },
1668        );
1669        Ok(())
1670    }
1671
1672    async fn get_private_user_info(
1673        self: Arc<Self>,
1674        request: Message<proto::GetPrivateUserInfo>,
1675        response: Response<proto::GetPrivateUserInfo>,
1676    ) -> Result<()> {
1677        let metrics_id = self
1678            .app_state
1679            .db
1680            .get_user_metrics_id(request.sender_user_id)
1681            .await?;
1682        let user = self
1683            .app_state
1684            .db
1685            .get_user_by_id(request.sender_user_id)
1686            .await?
1687            .ok_or_else(|| anyhow!("user not found"))?;
1688        response.send(proto::GetPrivateUserInfoResponse {
1689            metrics_id,
1690            staff: user.admin,
1691        })?;
1692        Ok(())
1693    }
1694
1695    pub(crate) async fn store(&self) -> StoreGuard<'_> {
1696        #[cfg(test)]
1697        tokio::task::yield_now().await;
1698        let guard = self.store.lock().await;
1699        #[cfg(test)]
1700        tokio::task::yield_now().await;
1701        StoreGuard {
1702            guard,
1703            _not_send: PhantomData,
1704        }
1705    }
1706
1707    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1708        ServerSnapshot {
1709            store: self.store().await,
1710            peer: &self.peer,
1711        }
1712    }
1713}
1714
1715impl<'a> Deref for StoreGuard<'a> {
1716    type Target = Store;
1717
1718    fn deref(&self) -> &Self::Target {
1719        &*self.guard
1720    }
1721}
1722
1723impl<'a> DerefMut for StoreGuard<'a> {
1724    fn deref_mut(&mut self) -> &mut Self::Target {
1725        &mut *self.guard
1726    }
1727}
1728
/// Runs consistency checks when a `StoreGuard` is released.
///
/// In test builds, `check_invariants` runs every time a guard is dropped, so
/// any handler that leaves the store inconsistent is caught at that point.
/// In non-test builds this impl adds no work beyond dropping the fields.
impl<'a> Drop for StoreGuard<'a> {
    fn drop(&mut self) {
        // Gated with `#[cfg(test)]` rather than `cfg!(test)`:
        // `check_invariants` may only be compiled in test builds.
        // NOTE(review): confirm against store.rs.
        #[cfg(test)]
        self.check_invariants();
    }
}
1735
1736impl Executor for RealExecutor {
1737    type Sleep = Sleep;
1738
1739    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1740        tokio::task::spawn(future);
1741    }
1742
1743    fn sleep(&self, duration: Duration) -> Self::Sleep {
1744        tokio::time::sleep(duration)
1745    }
1746}
1747
1748fn broadcast<F>(
1749    sender_id: ConnectionId,
1750    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1751    mut f: F,
1752) where
1753    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1754{
1755    for receiver_id in receiver_ids {
1756        if receiver_id != sender_id {
1757            f(receiver_id).trace_err();
1758        }
1759    }
1760}
1761
// Name of the request header a client uses to announce its RPC protocol
// version; read through the `ProtocolVersion` typed header below.
// It lives in a static because `Header::name` must return `&'static HeaderName`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1765
/// Typed value of the `x-zed-protocol-version` request header.
///
/// `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION`
/// before accepting a websocket upgrade.
pub struct ProtocolVersion(u32);
1767
1768impl Header for ProtocolVersion {
1769    fn name() -> &'static HeaderName {
1770        &ZED_PROTOCOL_VERSION
1771    }
1772
1773    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1774    where
1775        Self: Sized,
1776        I: Iterator<Item = &'i axum::http::HeaderValue>,
1777    {
1778        let version = values
1779            .next()
1780            .ok_or_else(axum::headers::Error::invalid)?
1781            .to_str()
1782            .map_err(|_| axum::headers::Error::invalid())?
1783            .parse()
1784            .map_err(|_| axum::headers::Error::invalid())?;
1785        Ok(Self(version))
1786    }
1787
1788    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1789        values.extend([self.0.to_string().parse().unwrap()]);
1790    }
1791}
1792
/// Builds the collab HTTP router: `/rpc` (websocket) and `/metrics`.
///
/// The order of `.route` and `.layer` calls matters. In axum, `Router::layer`
/// wraps only the routes registered before the call. As a result:
/// - `/rpc` goes through `auth::validate_header` and receives the
///   `AppState` extension;
/// - `/metrics` is registered afterwards and is not behind that middleware;
/// - the final `Extension(server)` layer applies to both routes.
///
/// NOTE(review): an unauthenticated `/metrics` is presumably intentional so
/// that a Prometheus scraper can reach it. Confirm before reordering.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1804
/// Axum handler for `GET /rpc`. It upgrades the request to a websocket and
/// hands the connection to `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required` and no upgrade when the client's
/// `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`.
///
/// The `User` extension is presumably inserted by the `auth::validate_header`
/// middleware that `routes` installs on this path. TODO confirm.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients built against a different protocol version before upgrading.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt axum's websocket so `Connection` sees tungstenite messages:
        // - incoming axum messages are converted to tungstenite;
        // - outgoing tungstenite messages are converted back to axum before
        //   being written to the socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // The upgrade callback cannot return an error, so a failure from
            // `handle_connection` is passed to `util::ResultExt::log_err`.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1835
1836pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1837    let metrics = server.store().await.metrics();
1838    METRIC_CONNECTIONS.set(metrics.connections as _);
1839    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1840
1841    let encoder = prometheus::TextEncoder::new();
1842    let metric_families = prometheus::gather();
1843    match encoder.encode_to_string(&metric_families) {
1844        Ok(string) => (StatusCode::OK, string).into_response(),
1845        Err(error) => (
1846            StatusCode::INTERNAL_SERVER_ERROR,
1847            format!("failed to encode metrics {:?}", error),
1848        )
1849            .into_response(),
1850    }
1851}
1852
1853fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1854    match message {
1855        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1856        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1857        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1858        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1859        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1860            code: frame.code.into(),
1861            reason: frame.reason,
1862        })),
1863    }
1864}
1865
1866fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1867    match message {
1868        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1869        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1870        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1871        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1872        AxumMessage::Close(frame) => {
1873            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1874                code: frame.code.into(),
1875                reason: frame.reason,
1876            }))
1877        }
1878    }
1879}
1880
/// Extension for `Result`-like values that reports errors through `tracing`
/// and does not propagate them.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Returns the success value, or emits a `tracing` error event for the
    /// error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1886
1887impl<T, E> ResultExt for Result<T, E>
1888where
1889    E: std::fmt::Debug,
1890{
1891    type Ok = T;
1892
1893    fn trace_err(self) -> Option<T> {
1894        match self {
1895            Ok(value) => Some(value),
1896            Err(error) => {
1897                tracing::error!("{:?}", error);
1898                None
1899            }
1900        }
1901    }
1902}