rpc.rs

   1mod store;
   2
   3use crate::{
   4    auth,
   5    db::{self, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26use futures::{
  27    channel::oneshot,
  28    future::{self, BoxFuture},
  29    stream::FuturesUnordered,
  30    FutureExt, SinkExt, StreamExt, TryStreamExt,
  31};
  32use lazy_static::lazy_static;
  33use prometheus::{register_int_gauge, IntGauge};
  34use rpc::{
  35    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  36    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  37};
  38use serde::{Serialize, Serializer};
  39use std::{
  40    any::TypeId,
  41    future::Future,
  42    marker::PhantomData,
  43    net::SocketAddr,
  44    ops::{Deref, DerefMut},
  45    rc::Rc,
  46    sync::{
  47        atomic::{AtomicBool, Ordering::SeqCst},
  48        Arc,
  49    },
  50    time::Duration,
  51};
  52pub use store::{Store, Worktree};
  53use tokio::{
  54    sync::{Mutex, MutexGuard},
  55    time::Sleep,
  56};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
// Process-wide Prometheus gauges, created lazily on first access. `register_int_gauge!`
// registers each gauge with the prometheus crate's default registry; the `unwrap` panics if
// registration fails (e.g. a metric with the same name was already registered).
lazy_static! {
    // Gauge "connections": number of connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge "shared_projects": number of open projects with one or more guests.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  69
/// Type-erased entry in [`Server`]'s dispatch table.
///
/// Called with the server, the id of the user that sent the message, and the boxed envelope;
/// returns a boxed `'static` future that processes the message.
type MessageHandler = Box<
    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
>;
  73
/// An incoming message as handed to a handler: the decoded payload plus who sent it.
struct Message<T> {
    /// The user associated with the connection the message arrived on.
    sender_user_id: UserId,
    /// The connection the message arrived on.
    sender_connection_id: ConnectionId,
    /// The decoded message body.
    payload: T,
}
  79
/// Handle through which a request handler sends its reply.
///
/// [`Response::send`] consumes the handle, so at most one reply can be sent through it.
/// `responded` is shared with the dispatcher built in `add_request_handler`, which inspects it
/// after the handler finishes.
struct Response<R> {
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by [`Response::send`].
    responded: Arc<AtomicBool>,
}
  85
  86impl<R: RequestMessage> Response<R> {
  87    fn send(self, payload: R::Response) -> Result<()> {
  88        self.responded.store(true, SeqCst);
  89        self.server.peer.respond(self.receipt, payload)?;
  90        Ok(())
  91    }
  92}
  93
/// RPC server that routes messages from connected clients to registered handlers.
pub struct Server {
    /// Owns the client connections and sends messages/responses over them.
    peer: Arc<Peer>,
    /// In-memory server state (e.g. which connections belong to which user), behind an
    /// async (tokio) mutex.
    pub(crate) store: Mutex<Store>,
    /// Shared application state: database, configuration, optional LiveKit client.
    app_state: Arc<AppState>,
    /// Dispatch table keyed by the `TypeId` of the message payload type.
    handlers: HashMap<TypeId, MessageHandler>,
}
 100
/// Runtime abstraction used by [`Server::handle_connection`] to spawn background work and
/// create timers.
pub trait Executor: Send + Clone {
    /// Timer future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` to run independently of the caller; no handle is returned.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration`; `handle_connection` uses it to drive the peer's
    /// connection timers.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 106
/// Stateless [`Executor`] marker type. Its `impl Executor` is not in this part of the file;
/// presumably it is backed by tokio (this file imports `tokio::time::Sleep`) — confirm there.
#[derive(Clone)]
pub struct RealExecutor;
 109
/// Lock guard over the server's [`Store`].
///
/// The `PhantomData<Rc<()>>` field makes the guard `!Send`, so any future that keeps it alive
/// across an `.await` is `!Send` too. Presumably this is meant to stop the store lock from
/// being held across await points in handlers that must be `Send`.
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
 114
/// Serializable view of the server: the peer plus the locked store.
///
/// The store is held through a [`StoreGuard`], so the store lock stays held for as long as
/// the snapshot lives. The guard is serialized via [`serialize_deref`], i.e. as the value it
/// dereferences to.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
 121
 122pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 123where
 124    S: Serializer,
 125    T: Deref<Target = U>,
 126    U: Serialize,
 127{
 128    Serialize::serialize(value.deref(), serializer)
 129}
 130
 131impl Server {
    /// Creates the server and registers a handler for every message type it serves.
    ///
    /// `add_request_handler` is used for messages that must be answered (`RequestMessage`s),
    /// `add_message_handler` for fire-and-forget messages. The server is returned in an `Arc`
    /// because handlers receive `Arc<Server>`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
        };

        server
            .add_request_handler(Server::ping)
            .add_request_handler(Server::create_room)
            .add_request_handler(Server::join_room)
            .add_message_handler(Server::leave_room)
            .add_request_handler(Server::call)
            .add_request_handler(Server::cancel_call)
            .add_message_handler(Server::decline_call)
            .add_request_handler(Server::update_participant_location)
            .add_request_handler(Server::share_project)
            .add_message_handler(Server::unshare_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_request_handler(Server::update_project)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_message_handler(Server::update_diagnostic_summary)
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            .add_message_handler(Server::create_buffer_for_peer)
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_message_handler(Server::update_followers)
            .add_message_handler(Server::update_diff_base)
            .add_request_handler(Server::get_private_user_info);

        Arc::new(server)
    }
 201
 202    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 203    where
 204        F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
 205        Fut: 'static + Send + Future<Output = Result<()>>,
 206        M: EnvelopedMessage,
 207    {
 208        let prev_handler = self.handlers.insert(
 209            TypeId::of::<M>(),
 210            Box::new(move |server, sender_user_id, envelope| {
 211                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 212                let span = info_span!(
 213                    "handle message",
 214                    payload_type = envelope.payload_type_name()
 215                );
 216                span.in_scope(|| {
 217                    tracing::info!(
 218                        payload_type = envelope.payload_type_name(),
 219                        "message received"
 220                    );
 221                });
 222                let future = (handler)(server, sender_user_id, *envelope);
 223                async move {
 224                    if let Err(error) = future.await {
 225                        tracing::error!(%error, "error handling message");
 226                    }
 227                }
 228                .instrument(span)
 229                .boxed()
 230            }),
 231        );
 232        if prev_handler.is_some() {
 233            panic!("registered a handler for the same message twice");
 234        }
 235        self
 236    }
 237
 238    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 239    where
 240        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
 241        Fut: 'static + Send + Future<Output = Result<()>>,
 242        M: EnvelopedMessage,
 243    {
 244        self.add_handler(move |server, sender_user_id, envelope| {
 245            handler(
 246                server,
 247                Message {
 248                    sender_user_id,
 249                    sender_connection_id: envelope.sender_id,
 250                    payload: envelope.payload,
 251                },
 252            )
 253        });
 254        self
 255    }
 256
 257    /// Handle a request while holding a lock to the store. This is useful when we're registering
 258    /// a connection but we want to respond on the connection before anybody else can send on it.
 259    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 260    where
 261        F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
 262        Fut: Send + Future<Output = Result<()>>,
 263        M: RequestMessage,
 264    {
 265        let handler = Arc::new(handler);
 266        self.add_handler(move |server, sender_user_id, envelope| {
 267            let receipt = envelope.receipt();
 268            let handler = handler.clone();
 269            async move {
 270                let request = Message {
 271                    sender_user_id,
 272                    sender_connection_id: envelope.sender_id,
 273                    payload: envelope.payload,
 274                };
 275                let responded = Arc::new(AtomicBool::default());
 276                let response = Response {
 277                    server: server.clone(),
 278                    responded: responded.clone(),
 279                    receipt,
 280                };
 281                match (handler)(server.clone(), request, response).await {
 282                    Ok(()) => {
 283                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 284                            Ok(())
 285                        } else {
 286                            Err(anyhow!("handler did not send a response"))?
 287                        }
 288                    }
 289                    Err(error) => {
 290                        server.peer.respond_with_error(
 291                            receipt,
 292                            proto::Error {
 293                                message: error.to_string(),
 294                            },
 295                        )?;
 296                        Err(error)
 297                    }
 298                }
 299            }
 300        })
 301    }
 302
    /// Returns a future that drives one client connection from handshake until it closes.
    ///
    /// When polled, the future:
    /// 1. registers `connection` with the peer (timers come from `executor.sleep`), sends a
    ///    `Hello` carrying the new connection id, and reports that id through
    ///    `send_connection_id` if a sender was supplied;
    /// 2. if the user has never connected before, sends `ShowContacts` and records in the
    ///    database that the user has now connected once;
    /// 3. loads the user's contacts and invite code, adds the connection to the store, and
    ///    sends the initial contacts update (plus `UpdateInviteInfo` if there is an invite code);
    /// 4. forwards a pending incoming call, if any, and refreshes the user's contact status;
    /// 5. dispatches incoming messages to the registered handlers until I/O fails or the
    ///    incoming stream ends, then signs the connection out.
    ///
    /// Errors from the setup steps (1–4) are returned; a sign-out error is only logged.
    ///
    /// NOTE(review): when a setup step fails after `add_connection`, the `?` returns early and
    /// `sign_out` is never reached, so the peer connection (and, after step 3, the store
    /// entry) does not appear to be cleaned up here — confirm whether a caller handles this.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `sign_out` takes `&mut Arc<Self>`, hence the mutable clone.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiving side may already be gone; the send result is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Both queries are polled concurrently; the first error aborts the connection.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            {
                let mut store = this.store().await;
                store.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks its arms in the order written: I/O completion first,
                // then progress on in-flight foreground handlers, then newly arrived messages.
                // Background messages are spawned detached instead of being queued here.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id, user_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 424
 425    #[instrument(skip(self), err)]
 426    async fn sign_out(
 427        self: &mut Arc<Self>,
 428        connection_id: ConnectionId,
 429        user_id: UserId,
 430    ) -> Result<()> {
 431        self.peer.disconnect(connection_id);
 432        let decline_calls = {
 433            let mut store = self.store().await;
 434            store.remove_connection(connection_id)?;
 435            let mut connections = store.connection_ids_for_user(user_id);
 436            connections.next().is_none()
 437        };
 438
 439        self.leave_room_for_connection(connection_id, user_id)
 440            .await
 441            .trace_err();
 442        if decline_calls {
 443            if let Some(room) = self
 444                .app_state
 445                .db
 446                .decline_call(None, user_id)
 447                .await
 448                .trace_err()
 449            {
 450                self.room_updated(&room);
 451            }
 452        }
 453
 454        self.update_user_contacts(user_id).await?;
 455
 456        Ok(())
 457    }
 458
 459    pub async fn invite_code_redeemed(
 460        self: &Arc<Self>,
 461        inviter_id: UserId,
 462        invitee_id: UserId,
 463    ) -> Result<()> {
 464        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 465            if let Some(code) = &user.invite_code {
 466                let store = self.store().await;
 467                let invitee_contact = store.contact_for_user(invitee_id, true, false);
 468                for connection_id in store.connection_ids_for_user(inviter_id) {
 469                    self.peer.send(
 470                        connection_id,
 471                        proto::UpdateContacts {
 472                            contacts: vec![invitee_contact.clone()],
 473                            ..Default::default()
 474                        },
 475                    )?;
 476                    self.peer.send(
 477                        connection_id,
 478                        proto::UpdateInviteInfo {
 479                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 480                            count: user.invite_count as u32,
 481                        },
 482                    )?;
 483                }
 484            }
 485        }
 486        Ok(())
 487    }
 488
 489    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 490        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 491            if let Some(invite_code) = &user.invite_code {
 492                let store = self.store().await;
 493                for connection_id in store.connection_ids_for_user(user_id) {
 494                    self.peer.send(
 495                        connection_id,
 496                        proto::UpdateInviteInfo {
 497                            url: format!(
 498                                "{}{}",
 499                                self.app_state.config.invite_link_prefix, invite_code
 500                            ),
 501                            count: user.invite_count as u32,
 502                        },
 503                    )?;
 504                }
 505            }
 506        }
 507        Ok(())
 508    }
 509
 510    async fn ping(
 511        self: Arc<Server>,
 512        _: Message<proto::Ping>,
 513        response: Response<proto::Ping>,
 514    ) -> Result<()> {
 515        response.send(proto::Ack {})?;
 516        Ok(())
 517    }
 518
 519    async fn create_room(
 520        self: Arc<Server>,
 521        request: Message<proto::CreateRoom>,
 522        response: Response<proto::CreateRoom>,
 523    ) -> Result<()> {
 524        let room = self
 525            .app_state
 526            .db
 527            .create_room(request.sender_user_id, request.sender_connection_id)
 528            .await?;
 529
 530        let live_kit_connection_info =
 531            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 532                if let Some(_) = live_kit
 533                    .create_room(room.live_kit_room.clone())
 534                    .await
 535                    .trace_err()
 536                {
 537                    if let Some(token) = live_kit
 538                        .room_token(
 539                            &room.live_kit_room,
 540                            &request.sender_connection_id.to_string(),
 541                        )
 542                        .trace_err()
 543                    {
 544                        Some(proto::LiveKitConnectionInfo {
 545                            server_url: live_kit.url().into(),
 546                            token,
 547                        })
 548                    } else {
 549                        None
 550                    }
 551                } else {
 552                    None
 553                }
 554            } else {
 555                None
 556            };
 557
 558        response.send(proto::CreateRoomResponse {
 559            room: Some(room),
 560            live_kit_connection_info,
 561        })?;
 562        self.update_user_contacts(request.sender_user_id).await?;
 563        Ok(())
 564    }
 565
 566    async fn join_room(
 567        self: Arc<Server>,
 568        request: Message<proto::JoinRoom>,
 569        response: Response<proto::JoinRoom>,
 570    ) -> Result<()> {
 571        let room = self
 572            .app_state
 573            .db
 574            .join_room(
 575                RoomId::from_proto(request.payload.id),
 576                request.sender_user_id,
 577                request.sender_connection_id,
 578            )
 579            .await?;
 580        for connection_id in self
 581            .store()
 582            .await
 583            .connection_ids_for_user(request.sender_user_id)
 584        {
 585            self.peer
 586                .send(connection_id, proto::CallCanceled {})
 587                .trace_err();
 588        }
 589
 590        let live_kit_connection_info =
 591            if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 592                if let Some(token) = live_kit
 593                    .room_token(
 594                        &room.live_kit_room,
 595                        &request.sender_connection_id.to_string(),
 596                    )
 597                    .trace_err()
 598                {
 599                    Some(proto::LiveKitConnectionInfo {
 600                        server_url: live_kit.url().into(),
 601                        token,
 602                    })
 603                } else {
 604                    None
 605                }
 606            } else {
 607                None
 608            };
 609
 610        self.room_updated(&room);
 611        response.send(proto::JoinRoomResponse {
 612            room: Some(room),
 613            live_kit_connection_info,
 614        })?;
 615
 616        self.update_user_contacts(request.sender_user_id).await?;
 617        Ok(())
 618    }
 619
 620    async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
 621        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
 622            .await
 623    }
 624
 625    async fn leave_room_for_connection(
 626        self: &Arc<Server>,
 627        connection_id: ConnectionId,
 628        user_id: UserId,
 629    ) -> Result<()> {
 630        let mut contacts_to_update = HashSet::default();
 631
 632        let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else {
 633            return Err(anyhow!("no room to leave"))?;
 634        };
 635        contacts_to_update.insert(user_id);
 636
 637        for project in left_room.left_projects.into_values() {
 638            if project.host_user_id == user_id {
 639                for connection_id in project.connection_ids {
 640                    self.peer
 641                        .send(
 642                            connection_id,
 643                            proto::UnshareProject {
 644                                project_id: project.id.to_proto(),
 645                            },
 646                        )
 647                        .trace_err();
 648                }
 649            } else {
 650                for connection_id in project.connection_ids {
 651                    self.peer
 652                        .send(
 653                            connection_id,
 654                            proto::RemoveProjectCollaborator {
 655                                project_id: project.id.to_proto(),
 656                                peer_id: connection_id.0,
 657                            },
 658                        )
 659                        .trace_err();
 660                }
 661
 662                self.peer
 663                    .send(
 664                        connection_id,
 665                        proto::UnshareProject {
 666                            project_id: project.id.to_proto(),
 667                        },
 668                    )
 669                    .trace_err();
 670            }
 671        }
 672
 673        self.room_updated(&left_room.room);
 674        {
 675            let store = self.store().await;
 676            for canceled_user_id in left_room.canceled_calls_to_user_ids {
 677                for connection_id in store.connection_ids_for_user(canceled_user_id) {
 678                    self.peer
 679                        .send(connection_id, proto::CallCanceled {})
 680                        .trace_err();
 681                }
 682                contacts_to_update.insert(canceled_user_id);
 683            }
 684        }
 685
 686        for contact_user_id in contacts_to_update {
 687            self.update_user_contacts(contact_user_id).await?;
 688        }
 689
 690        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
 691            live_kit
 692                .remove_participant(
 693                    left_room.room.live_kit_room.clone(),
 694                    connection_id.to_string(),
 695                )
 696                .await
 697                .trace_err();
 698
 699            if left_room.room.participants.is_empty() {
 700                live_kit
 701                    .delete_room(left_room.room.live_kit_room)
 702                    .await
 703                    .trace_err();
 704            }
 705        }
 706
 707        Ok(())
 708    }
 709
 710    async fn call(
 711        self: Arc<Server>,
 712        request: Message<proto::Call>,
 713        response: Response<proto::Call>,
 714    ) -> Result<()> {
 715        let room_id = RoomId::from_proto(request.payload.room_id);
 716        let calling_user_id = request.sender_user_id;
 717        let calling_connection_id = request.sender_connection_id;
 718        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 719        let initial_project_id = request
 720            .payload
 721            .initial_project_id
 722            .map(ProjectId::from_proto);
 723        if !self
 724            .app_state
 725            .db
 726            .has_contact(calling_user_id, called_user_id)
 727            .await?
 728        {
 729            return Err(anyhow!("cannot call a user who isn't a contact"))?;
 730        }
 731
 732        let (room, incoming_call) = self
 733            .app_state
 734            .db
 735            .call(
 736                room_id,
 737                calling_user_id,
 738                calling_connection_id,
 739                called_user_id,
 740                initial_project_id,
 741            )
 742            .await?;
 743        self.room_updated(&room);
 744        self.update_user_contacts(called_user_id).await?;
 745
 746        let mut calls = self
 747            .store()
 748            .await
 749            .connection_ids_for_user(called_user_id)
 750            .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
 751            .collect::<FuturesUnordered<_>>();
 752
 753        while let Some(call_response) = calls.next().await {
 754            match call_response.as_ref() {
 755                Ok(_) => {
 756                    response.send(proto::Ack {})?;
 757                    return Ok(());
 758                }
 759                Err(_) => {
 760                    call_response.trace_err();
 761                }
 762            }
 763        }
 764
 765        let room = self
 766            .app_state
 767            .db
 768            .call_failed(room_id, called_user_id)
 769            .await?;
 770        self.room_updated(&room);
 771        self.update_user_contacts(called_user_id).await?;
 772
 773        Err(anyhow!("failed to ring user"))?
 774    }
 775
 776    async fn cancel_call(
 777        self: Arc<Server>,
 778        request: Message<proto::CancelCall>,
 779        response: Response<proto::CancelCall>,
 780    ) -> Result<()> {
 781        let called_user_id = UserId::from_proto(request.payload.called_user_id);
 782        let room_id = RoomId::from_proto(request.payload.room_id);
 783        let room = self
 784            .app_state
 785            .db
 786            .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
 787            .await?;
 788        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
 789            self.peer
 790                .send(connection_id, proto::CallCanceled {})
 791                .trace_err();
 792        }
 793        self.room_updated(&room);
 794        response.send(proto::Ack {})?;
 795
 796        self.update_user_contacts(called_user_id).await?;
 797        Ok(())
 798    }
 799
 800    async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
 801        let room_id = RoomId::from_proto(message.payload.room_id);
 802        let room = self
 803            .app_state
 804            .db
 805            .decline_call(Some(room_id), message.sender_user_id)
 806            .await?;
 807        for connection_id in self
 808            .store()
 809            .await
 810            .connection_ids_for_user(message.sender_user_id)
 811        {
 812            self.peer
 813                .send(connection_id, proto::CallCanceled {})
 814                .trace_err();
 815        }
 816        self.room_updated(&room);
 817        self.update_user_contacts(message.sender_user_id).await?;
 818        Ok(())
 819    }
 820
 821    async fn update_participant_location(
 822        self: Arc<Server>,
 823        request: Message<proto::UpdateParticipantLocation>,
 824        response: Response<proto::UpdateParticipantLocation>,
 825    ) -> Result<()> {
 826        let room_id = RoomId::from_proto(request.payload.room_id);
 827        let location = request
 828            .payload
 829            .location
 830            .ok_or_else(|| anyhow!("invalid location"))?;
 831        let room = self
 832            .app_state
 833            .db
 834            .update_room_participant_location(room_id, request.sender_connection_id, location)
 835            .await?;
 836        self.room_updated(&room);
 837        response.send(proto::Ack {})?;
 838        Ok(())
 839    }
 840
 841    fn room_updated(&self, room: &proto::Room) {
 842        for participant in &room.participants {
 843            self.peer
 844                .send(
 845                    ConnectionId(participant.peer_id),
 846                    proto::RoomUpdated {
 847                        room: Some(room.clone()),
 848                    },
 849                )
 850                .trace_err();
 851        }
 852    }
 853
 854    async fn share_project(
 855        self: Arc<Server>,
 856        request: Message<proto::ShareProject>,
 857        response: Response<proto::ShareProject>,
 858    ) -> Result<()> {
 859        let (project_id, room) = self
 860            .app_state
 861            .db
 862            .share_project(
 863                RoomId::from_proto(request.payload.room_id),
 864                request.sender_connection_id,
 865                &request.payload.worktrees,
 866            )
 867            .await
 868            .unwrap();
 869        response.send(proto::ShareProjectResponse {
 870            project_id: project_id.to_proto(),
 871        })?;
 872        self.room_updated(&room);
 873
 874        Ok(())
 875    }
 876
 877    async fn unshare_project(
 878        self: Arc<Server>,
 879        message: Message<proto::UnshareProject>,
 880    ) -> Result<()> {
 881        let project_id = ProjectId::from_proto(message.payload.project_id);
 882        let mut store = self.store().await;
 883        let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?;
 884        broadcast(
 885            message.sender_connection_id,
 886            project.guest_connection_ids(),
 887            |conn_id| self.peer.send(conn_id, message.payload.clone()),
 888        );
 889        self.room_updated(room);
 890
 891        Ok(())
 892    }
 893
 894    async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
 895        let contacts = self.app_state.db.get_contacts(user_id).await?;
 896        let busy = self.app_state.db.is_user_busy(user_id).await?;
 897        let store = self.store().await;
 898        let updated_contact = store.contact_for_user(user_id, false, busy);
 899        for contact in contacts {
 900            if let db::Contact::Accepted {
 901                user_id: contact_user_id,
 902                ..
 903            } = contact
 904            {
 905                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
 906                    self.peer
 907                        .send(
 908                            contact_conn_id,
 909                            proto::UpdateContacts {
 910                                contacts: vec![updated_contact.clone()],
 911                                remove_contacts: Default::default(),
 912                                incoming_requests: Default::default(),
 913                                remove_incoming_requests: Default::default(),
 914                                outgoing_requests: Default::default(),
 915                                remove_outgoing_requests: Default::default(),
 916                            },
 917                        )
 918                        .trace_err();
 919                }
 920            }
 921        }
 922        Ok(())
 923    }
 924
 925    async fn join_project(
 926        self: Arc<Server>,
 927        request: Message<proto::JoinProject>,
 928        response: Response<proto::JoinProject>,
 929    ) -> Result<()> {
 930        let project_id = ProjectId::from_proto(request.payload.project_id);
 931        let guest_user_id = request.sender_user_id;
 932
 933        tracing::info!(%project_id, "join project");
 934
 935        let (project, replica_id) = self
 936            .app_state
 937            .db
 938            .join_project(project_id, request.sender_connection_id)
 939            .await?;
 940
 941        let collaborators = project
 942            .collaborators
 943            .iter()
 944            .map(|collaborator| proto::Collaborator {
 945                peer_id: collaborator.connection_id as u32,
 946                replica_id: collaborator.replica_id.0 as u32,
 947                user_id: collaborator.user_id.to_proto(),
 948            })
 949            .collect::<Vec<_>>();
 950        let worktrees = project
 951            .worktrees
 952            .iter()
 953            .map(|(id, worktree)| proto::WorktreeMetadata {
 954                id: id.to_proto(),
 955                root_name: worktree.root_name.clone(),
 956                visible: worktree.visible,
 957                abs_path: worktree.abs_path.clone(),
 958            })
 959            .collect::<Vec<_>>();
 960
 961        for collaborator in &project.collaborators {
 962            let connection_id = ConnectionId(collaborator.connection_id as u32);
 963            if connection_id != request.sender_connection_id {
 964                self.peer
 965                    .send(
 966                        connection_id,
 967                        proto::AddProjectCollaborator {
 968                            project_id: project_id.to_proto(),
 969                            collaborator: Some(proto::Collaborator {
 970                                peer_id: request.sender_connection_id.0,
 971                                replica_id: replica_id.0 as u32,
 972                                user_id: guest_user_id.to_proto(),
 973                            }),
 974                        },
 975                    )
 976                    .trace_err();
 977            }
 978        }
 979
 980        // First, we send the metadata associated with each worktree.
 981        response.send(proto::JoinProjectResponse {
 982            worktrees: worktrees.clone(),
 983            replica_id: replica_id.0 as u32,
 984            collaborators: collaborators.clone(),
 985            language_servers: project.language_servers.clone(),
 986        })?;
 987
 988        for (worktree_id, worktree) in project.worktrees {
 989            #[cfg(any(test, feature = "test-support"))]
 990            const MAX_CHUNK_SIZE: usize = 2;
 991            #[cfg(not(any(test, feature = "test-support")))]
 992            const MAX_CHUNK_SIZE: usize = 256;
 993
 994            // Stream this worktree's entries.
 995            let message = proto::UpdateWorktree {
 996                project_id: project_id.to_proto(),
 997                worktree_id: worktree_id.to_proto(),
 998                abs_path: worktree.abs_path.clone(),
 999                root_name: worktree.root_name,
1000                updated_entries: worktree.entries,
1001                removed_entries: Default::default(),
1002                scan_id: worktree.scan_id,
1003                is_last_update: worktree.is_complete,
1004            };
1005            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1006                self.peer
1007                    .send(request.sender_connection_id, update.clone())?;
1008            }
1009
1010            // Stream this worktree's diagnostics.
1011            for summary in worktree.diagnostic_summaries {
1012                self.peer.send(
1013                    request.sender_connection_id,
1014                    proto::UpdateDiagnosticSummary {
1015                        project_id: project_id.to_proto(),
1016                        worktree_id: worktree.id.to_proto(),
1017                        summary: Some(summary),
1018                    },
1019                )?;
1020            }
1021        }
1022
1023        for language_server in &project.language_servers {
1024            self.peer.send(
1025                request.sender_connection_id,
1026                proto::UpdateLanguageServer {
1027                    project_id: project_id.to_proto(),
1028                    language_server_id: language_server.id,
1029                    variant: Some(
1030                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1031                            proto::LspDiskBasedDiagnosticsUpdated {},
1032                        ),
1033                    ),
1034                },
1035            )?;
1036        }
1037
1038        Ok(())
1039    }
1040
1041    async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1042        let sender_id = request.sender_connection_id;
1043        let project_id = ProjectId::from_proto(request.payload.project_id);
1044        let project;
1045        {
1046            let mut store = self.store().await;
1047            project = store.leave_project(project_id, sender_id)?;
1048            tracing::info!(
1049                %project_id,
1050                host_user_id = %project.host_user_id,
1051                host_connection_id = %project.host_connection_id,
1052                "leave project"
1053            );
1054
1055            if project.remove_collaborator {
1056                broadcast(sender_id, project.connection_ids, |conn_id| {
1057                    self.peer.send(
1058                        conn_id,
1059                        proto::RemoveProjectCollaborator {
1060                            project_id: project_id.to_proto(),
1061                            peer_id: sender_id.0,
1062                        },
1063                    )
1064                });
1065            }
1066        }
1067
1068        Ok(())
1069    }
1070
1071    async fn update_project(
1072        self: Arc<Server>,
1073        request: Message<proto::UpdateProject>,
1074        response: Response<proto::UpdateProject>,
1075    ) -> Result<()> {
1076        let project_id = ProjectId::from_proto(request.payload.project_id);
1077        let (room, guest_connection_ids) = self
1078            .app_state
1079            .db
1080            .update_project(
1081                project_id,
1082                request.sender_connection_id,
1083                &request.payload.worktrees,
1084            )
1085            .await?;
1086        broadcast(
1087            request.sender_connection_id,
1088            guest_connection_ids,
1089            |connection_id| {
1090                self.peer.forward_send(
1091                    request.sender_connection_id,
1092                    connection_id,
1093                    request.payload.clone(),
1094                )
1095            },
1096        );
1097        self.room_updated(&room);
1098        response.send(proto::Ack {})?;
1099
1100        Ok(())
1101    }
1102
1103    async fn update_worktree(
1104        self: Arc<Server>,
1105        request: Message<proto::UpdateWorktree>,
1106        response: Response<proto::UpdateWorktree>,
1107    ) -> Result<()> {
1108        let connection_ids = self
1109            .app_state
1110            .db
1111            .update_worktree(&request.payload, request.sender_connection_id)
1112            .await?;
1113
1114        broadcast(
1115            request.sender_connection_id,
1116            connection_ids,
1117            |connection_id| {
1118                self.peer.forward_send(
1119                    request.sender_connection_id,
1120                    connection_id,
1121                    request.payload.clone(),
1122                )
1123            },
1124        );
1125        response.send(proto::Ack {})?;
1126        Ok(())
1127    }
1128
1129    async fn update_diagnostic_summary(
1130        self: Arc<Server>,
1131        request: Message<proto::UpdateDiagnosticSummary>,
1132    ) -> Result<()> {
1133        let summary = request
1134            .payload
1135            .summary
1136            .clone()
1137            .ok_or_else(|| anyhow!("invalid summary"))?;
1138        let receiver_ids = self.store().await.update_diagnostic_summary(
1139            ProjectId::from_proto(request.payload.project_id),
1140            request.payload.worktree_id,
1141            request.sender_connection_id,
1142            summary,
1143        )?;
1144
1145        broadcast(
1146            request.sender_connection_id,
1147            receiver_ids,
1148            |connection_id| {
1149                self.peer.forward_send(
1150                    request.sender_connection_id,
1151                    connection_id,
1152                    request.payload.clone(),
1153                )
1154            },
1155        );
1156        Ok(())
1157    }
1158
1159    async fn start_language_server(
1160        self: Arc<Server>,
1161        request: Message<proto::StartLanguageServer>,
1162    ) -> Result<()> {
1163        let receiver_ids = self.store().await.start_language_server(
1164            ProjectId::from_proto(request.payload.project_id),
1165            request.sender_connection_id,
1166            request
1167                .payload
1168                .server
1169                .clone()
1170                .ok_or_else(|| anyhow!("invalid language server"))?,
1171        )?;
1172        broadcast(
1173            request.sender_connection_id,
1174            receiver_ids,
1175            |connection_id| {
1176                self.peer.forward_send(
1177                    request.sender_connection_id,
1178                    connection_id,
1179                    request.payload.clone(),
1180                )
1181            },
1182        );
1183        Ok(())
1184    }
1185
1186    async fn update_language_server(
1187        self: Arc<Server>,
1188        request: Message<proto::UpdateLanguageServer>,
1189    ) -> Result<()> {
1190        let receiver_ids = self.store().await.project_connection_ids(
1191            ProjectId::from_proto(request.payload.project_id),
1192            request.sender_connection_id,
1193        )?;
1194        broadcast(
1195            request.sender_connection_id,
1196            receiver_ids,
1197            |connection_id| {
1198                self.peer.forward_send(
1199                    request.sender_connection_id,
1200                    connection_id,
1201                    request.payload.clone(),
1202                )
1203            },
1204        );
1205        Ok(())
1206    }
1207
1208    async fn forward_project_request<T>(
1209        self: Arc<Server>,
1210        request: Message<T>,
1211        response: Response<T>,
1212    ) -> Result<()>
1213    where
1214        T: EntityMessage + RequestMessage,
1215    {
1216        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1217        let host_connection_id = self
1218            .store()
1219            .await
1220            .read_project(project_id, request.sender_connection_id)?
1221            .host_connection_id;
1222        let payload = self
1223            .peer
1224            .forward_request(
1225                request.sender_connection_id,
1226                host_connection_id,
1227                request.payload,
1228            )
1229            .await?;
1230
1231        // Ensure project still exists by the time we get the response from the host.
1232        self.store()
1233            .await
1234            .read_project(project_id, request.sender_connection_id)?;
1235
1236        response.send(payload)?;
1237        Ok(())
1238    }
1239
1240    async fn save_buffer(
1241        self: Arc<Server>,
1242        request: Message<proto::SaveBuffer>,
1243        response: Response<proto::SaveBuffer>,
1244    ) -> Result<()> {
1245        let project_id = ProjectId::from_proto(request.payload.project_id);
1246        let host = self
1247            .store()
1248            .await
1249            .read_project(project_id, request.sender_connection_id)?
1250            .host_connection_id;
1251        let response_payload = self
1252            .peer
1253            .forward_request(request.sender_connection_id, host, request.payload.clone())
1254            .await?;
1255
1256        let mut guests = self
1257            .store()
1258            .await
1259            .read_project(project_id, request.sender_connection_id)?
1260            .connection_ids();
1261        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
1262        broadcast(host, guests, |conn_id| {
1263            self.peer
1264                .forward_send(host, conn_id, response_payload.clone())
1265        });
1266        response.send(response_payload)?;
1267        Ok(())
1268    }
1269
1270    async fn create_buffer_for_peer(
1271        self: Arc<Server>,
1272        request: Message<proto::CreateBufferForPeer>,
1273    ) -> Result<()> {
1274        self.peer.forward_send(
1275            request.sender_connection_id,
1276            ConnectionId(request.payload.peer_id),
1277            request.payload,
1278        )?;
1279        Ok(())
1280    }
1281
1282    async fn update_buffer(
1283        self: Arc<Server>,
1284        request: Message<proto::UpdateBuffer>,
1285        response: Response<proto::UpdateBuffer>,
1286    ) -> Result<()> {
1287        let project_id = ProjectId::from_proto(request.payload.project_id);
1288        let receiver_ids = {
1289            let store = self.store().await;
1290            store.project_connection_ids(project_id, request.sender_connection_id)?
1291        };
1292
1293        broadcast(
1294            request.sender_connection_id,
1295            receiver_ids,
1296            |connection_id| {
1297                self.peer.forward_send(
1298                    request.sender_connection_id,
1299                    connection_id,
1300                    request.payload.clone(),
1301                )
1302            },
1303        );
1304        response.send(proto::Ack {})?;
1305        Ok(())
1306    }
1307
1308    async fn update_buffer_file(
1309        self: Arc<Server>,
1310        request: Message<proto::UpdateBufferFile>,
1311    ) -> Result<()> {
1312        let receiver_ids = self.store().await.project_connection_ids(
1313            ProjectId::from_proto(request.payload.project_id),
1314            request.sender_connection_id,
1315        )?;
1316        broadcast(
1317            request.sender_connection_id,
1318            receiver_ids,
1319            |connection_id| {
1320                self.peer.forward_send(
1321                    request.sender_connection_id,
1322                    connection_id,
1323                    request.payload.clone(),
1324                )
1325            },
1326        );
1327        Ok(())
1328    }
1329
1330    async fn buffer_reloaded(
1331        self: Arc<Server>,
1332        request: Message<proto::BufferReloaded>,
1333    ) -> Result<()> {
1334        let receiver_ids = self.store().await.project_connection_ids(
1335            ProjectId::from_proto(request.payload.project_id),
1336            request.sender_connection_id,
1337        )?;
1338        broadcast(
1339            request.sender_connection_id,
1340            receiver_ids,
1341            |connection_id| {
1342                self.peer.forward_send(
1343                    request.sender_connection_id,
1344                    connection_id,
1345                    request.payload.clone(),
1346                )
1347            },
1348        );
1349        Ok(())
1350    }
1351
1352    async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1353        let receiver_ids = self.store().await.project_connection_ids(
1354            ProjectId::from_proto(request.payload.project_id),
1355            request.sender_connection_id,
1356        )?;
1357        broadcast(
1358            request.sender_connection_id,
1359            receiver_ids,
1360            |connection_id| {
1361                self.peer.forward_send(
1362                    request.sender_connection_id,
1363                    connection_id,
1364                    request.payload.clone(),
1365                )
1366            },
1367        );
1368        Ok(())
1369    }
1370
1371    async fn follow(
1372        self: Arc<Self>,
1373        request: Message<proto::Follow>,
1374        response: Response<proto::Follow>,
1375    ) -> Result<()> {
1376        let project_id = ProjectId::from_proto(request.payload.project_id);
1377        let leader_id = ConnectionId(request.payload.leader_id);
1378        let follower_id = request.sender_connection_id;
1379        {
1380            let store = self.store().await;
1381            if !store
1382                .project_connection_ids(project_id, follower_id)?
1383                .contains(&leader_id)
1384            {
1385                Err(anyhow!("no such peer"))?;
1386            }
1387        }
1388
1389        let mut response_payload = self
1390            .peer
1391            .forward_request(request.sender_connection_id, leader_id, request.payload)
1392            .await?;
1393        response_payload
1394            .views
1395            .retain(|view| view.leader_id != Some(follower_id.0));
1396        response.send(response_payload)?;
1397        Ok(())
1398    }
1399
1400    async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1401        let project_id = ProjectId::from_proto(request.payload.project_id);
1402        let leader_id = ConnectionId(request.payload.leader_id);
1403        let store = self.store().await;
1404        if !store
1405            .project_connection_ids(project_id, request.sender_connection_id)?
1406            .contains(&leader_id)
1407        {
1408            Err(anyhow!("no such peer"))?;
1409        }
1410        self.peer
1411            .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1412        Ok(())
1413    }
1414
1415    async fn update_followers(
1416        self: Arc<Self>,
1417        request: Message<proto::UpdateFollowers>,
1418    ) -> Result<()> {
1419        let project_id = ProjectId::from_proto(request.payload.project_id);
1420        let store = self.store().await;
1421        let connection_ids =
1422            store.project_connection_ids(project_id, request.sender_connection_id)?;
1423        let leader_id = request
1424            .payload
1425            .variant
1426            .as_ref()
1427            .and_then(|variant| match variant {
1428                proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1429                proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1430                proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1431            });
1432        for follower_id in &request.payload.follower_ids {
1433            let follower_id = ConnectionId(*follower_id);
1434            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1435                self.peer.forward_send(
1436                    request.sender_connection_id,
1437                    follower_id,
1438                    request.payload.clone(),
1439                )?;
1440            }
1441        }
1442        Ok(())
1443    }
1444
1445    async fn get_users(
1446        self: Arc<Server>,
1447        request: Message<proto::GetUsers>,
1448        response: Response<proto::GetUsers>,
1449    ) -> Result<()> {
1450        let user_ids = request
1451            .payload
1452            .user_ids
1453            .into_iter()
1454            .map(UserId::from_proto)
1455            .collect();
1456        let users = self
1457            .app_state
1458            .db
1459            .get_users_by_ids(user_ids)
1460            .await?
1461            .into_iter()
1462            .map(|user| proto::User {
1463                id: user.id.to_proto(),
1464                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1465                github_login: user.github_login,
1466            })
1467            .collect();
1468        response.send(proto::UsersResponse { users })?;
1469        Ok(())
1470    }
1471
1472    async fn fuzzy_search_users(
1473        self: Arc<Server>,
1474        request: Message<proto::FuzzySearchUsers>,
1475        response: Response<proto::FuzzySearchUsers>,
1476    ) -> Result<()> {
1477        let query = request.payload.query;
1478        let db = &self.app_state.db;
1479        let users = match query.len() {
1480            0 => vec![],
1481            1 | 2 => db
1482                .get_user_by_github_account(&query, None)
1483                .await?
1484                .into_iter()
1485                .collect(),
1486            _ => db.fuzzy_search_users(&query, 10).await?,
1487        };
1488        let users = users
1489            .into_iter()
1490            .filter(|user| user.id != request.sender_user_id)
1491            .map(|user| proto::User {
1492                id: user.id.to_proto(),
1493                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1494                github_login: user.github_login,
1495            })
1496            .collect();
1497        response.send(proto::UsersResponse { users })?;
1498        Ok(())
1499    }
1500
1501    async fn request_contact(
1502        self: Arc<Server>,
1503        request: Message<proto::RequestContact>,
1504        response: Response<proto::RequestContact>,
1505    ) -> Result<()> {
1506        let requester_id = request.sender_user_id;
1507        let responder_id = UserId::from_proto(request.payload.responder_id);
1508        if requester_id == responder_id {
1509            return Err(anyhow!("cannot add yourself as a contact"))?;
1510        }
1511
1512        self.app_state
1513            .db
1514            .send_contact_request(requester_id, responder_id)
1515            .await?;
1516
1517        // Update outgoing contact requests of requester
1518        let mut update = proto::UpdateContacts::default();
1519        update.outgoing_requests.push(responder_id.to_proto());
1520        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1521            self.peer.send(connection_id, update.clone())?;
1522        }
1523
1524        // Update incoming contact requests of responder
1525        let mut update = proto::UpdateContacts::default();
1526        update
1527            .incoming_requests
1528            .push(proto::IncomingContactRequest {
1529                requester_id: requester_id.to_proto(),
1530                should_notify: true,
1531            });
1532        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1533            self.peer.send(connection_id, update.clone())?;
1534        }
1535
1536        response.send(proto::Ack {})?;
1537        Ok(())
1538    }
1539
1540    async fn respond_to_contact_request(
1541        self: Arc<Server>,
1542        request: Message<proto::RespondToContactRequest>,
1543        response: Response<proto::RespondToContactRequest>,
1544    ) -> Result<()> {
1545        let responder_id = request.sender_user_id;
1546        let requester_id = UserId::from_proto(request.payload.requester_id);
1547        if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1548            self.app_state
1549                .db
1550                .dismiss_contact_notification(responder_id, requester_id)
1551                .await?;
1552        } else {
1553            let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1554            self.app_state
1555                .db
1556                .respond_to_contact_request(responder_id, requester_id, accept)
1557                .await?;
1558            let busy = self.app_state.db.is_user_busy(requester_id).await?;
1559
1560            let store = self.store().await;
1561            // Update responder with new contact
1562            let mut update = proto::UpdateContacts::default();
1563            if accept {
1564                update
1565                    .contacts
1566                    .push(store.contact_for_user(requester_id, false, busy));
1567            }
1568            update
1569                .remove_incoming_requests
1570                .push(requester_id.to_proto());
1571            for connection_id in store.connection_ids_for_user(responder_id) {
1572                self.peer.send(connection_id, update.clone())?;
1573            }
1574
1575            // Update requester with new contact
1576            let mut update = proto::UpdateContacts::default();
1577            if accept {
1578                update
1579                    .contacts
1580                    .push(store.contact_for_user(responder_id, true, busy));
1581            }
1582            update
1583                .remove_outgoing_requests
1584                .push(responder_id.to_proto());
1585            for connection_id in store.connection_ids_for_user(requester_id) {
1586                self.peer.send(connection_id, update.clone())?;
1587            }
1588        }
1589
1590        response.send(proto::Ack {})?;
1591        Ok(())
1592    }
1593
1594    async fn remove_contact(
1595        self: Arc<Server>,
1596        request: Message<proto::RemoveContact>,
1597        response: Response<proto::RemoveContact>,
1598    ) -> Result<()> {
1599        let requester_id = request.sender_user_id;
1600        let responder_id = UserId::from_proto(request.payload.user_id);
1601        self.app_state
1602            .db
1603            .remove_contact(requester_id, responder_id)
1604            .await?;
1605
1606        // Update outgoing contact requests of requester
1607        let mut update = proto::UpdateContacts::default();
1608        update
1609            .remove_outgoing_requests
1610            .push(responder_id.to_proto());
1611        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1612            self.peer.send(connection_id, update.clone())?;
1613        }
1614
1615        // Update incoming contact requests of responder
1616        let mut update = proto::UpdateContacts::default();
1617        update
1618            .remove_incoming_requests
1619            .push(requester_id.to_proto());
1620        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1621            self.peer.send(connection_id, update.clone())?;
1622        }
1623
1624        response.send(proto::Ack {})?;
1625        Ok(())
1626    }
1627
1628    async fn update_diff_base(
1629        self: Arc<Server>,
1630        request: Message<proto::UpdateDiffBase>,
1631    ) -> Result<()> {
1632        let receiver_ids = self.store().await.project_connection_ids(
1633            ProjectId::from_proto(request.payload.project_id),
1634            request.sender_connection_id,
1635        )?;
1636        broadcast(
1637            request.sender_connection_id,
1638            receiver_ids,
1639            |connection_id| {
1640                self.peer.forward_send(
1641                    request.sender_connection_id,
1642                    connection_id,
1643                    request.payload.clone(),
1644                )
1645            },
1646        );
1647        Ok(())
1648    }
1649
1650    async fn get_private_user_info(
1651        self: Arc<Self>,
1652        request: Message<proto::GetPrivateUserInfo>,
1653        response: Response<proto::GetPrivateUserInfo>,
1654    ) -> Result<()> {
1655        let metrics_id = self
1656            .app_state
1657            .db
1658            .get_user_metrics_id(request.sender_user_id)
1659            .await?;
1660        let user = self
1661            .app_state
1662            .db
1663            .get_user_by_id(request.sender_user_id)
1664            .await?
1665            .ok_or_else(|| anyhow!("user not found"))?;
1666        response.send(proto::GetPrivateUserInfoResponse {
1667            metrics_id,
1668            staff: user.admin,
1669        })?;
1670        Ok(())
1671    }
1672
1673    pub(crate) async fn store(&self) -> StoreGuard<'_> {
1674        #[cfg(test)]
1675        tokio::task::yield_now().await;
1676        let guard = self.store.lock().await;
1677        #[cfg(test)]
1678        tokio::task::yield_now().await;
1679        StoreGuard {
1680            guard,
1681            _not_send: PhantomData,
1682        }
1683    }
1684
1685    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1686        ServerSnapshot {
1687            store: self.store().await,
1688            peer: &self.peer,
1689        }
1690    }
1691}
1692
1693impl<'a> Deref for StoreGuard<'a> {
1694    type Target = Store;
1695
1696    fn deref(&self) -> &Self::Target {
1697        &*self.guard
1698    }
1699}
1700
1701impl<'a> DerefMut for StoreGuard<'a> {
1702    fn deref_mut(&mut self) -> &mut Self::Target {
1703        &mut *self.guard
1704    }
1705}
1706
/// In test builds, runs the store's invariant checks every time a guard is released.
///
/// In non-test builds the `cfg(test)` call is compiled out and this drop does nothing extra.
impl<'a> Drop for StoreGuard<'a> {
    fn drop(&mut self) {
        // Checking on every release means a violation is reported right after the
        // guard-holding code that introduced it.
        #[cfg(test)]
        self.check_invariants();
    }
}
1713
1714impl Executor for RealExecutor {
1715    type Sleep = Sleep;
1716
1717    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1718        tokio::task::spawn(future);
1719    }
1720
1721    fn sleep(&self, duration: Duration) -> Self::Sleep {
1722        tokio::time::sleep(duration)
1723    }
1724}
1725
1726fn broadcast<F>(
1727    sender_id: ConnectionId,
1728    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1729    mut f: F,
1730) where
1731    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1732{
1733    for receiver_id in receiver_ids {
1734        if receiver_id != sender_id {
1735            f(receiver_id).trace_err();
1736        }
1737    }
1738}
1739
1740lazy_static! {
1741    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1742}
1743
1744pub struct ProtocolVersion(u32);
1745
1746impl Header for ProtocolVersion {
1747    fn name() -> &'static HeaderName {
1748        &ZED_PROTOCOL_VERSION
1749    }
1750
1751    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1752    where
1753        Self: Sized,
1754        I: Iterator<Item = &'i axum::http::HeaderValue>,
1755    {
1756        let version = values
1757            .next()
1758            .ok_or_else(axum::headers::Error::invalid)?
1759            .to_str()
1760            .map_err(|_| axum::headers::Error::invalid())?
1761            .parse()
1762            .map_err(|_| axum::headers::Error::invalid())?;
1763        Ok(Self(version))
1764    }
1765
1766    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1767        values.extend([self.0.to_string().parse().unwrap()]);
1768    }
1769}
1770
/// Builds the router exposing the RPC websocket (`/rpc`) and Prometheus metrics (`/metrics`).
///
/// Layer order matters: `Router::layer` only wraps routes registered before the call.
/// `/rpc` is therefore wrapped by the `AppState` extension and `auth::validate_header`.
/// `/metrics` is registered afterwards and is only wrapped by the final `Arc<Server>`
/// extension, which both routes receive.
///
/// NOTE(review): as a consequence `/metrics` is served without going through
/// `auth::validate_header`. Confirm this is intended, e.g. for an internal scraper.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1782
/// Upgrades an `/rpc` request to a websocket and runs the RPC connection on it.
///
/// If the `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`, the
/// request is rejected with `426 Upgrade Required` and no upgrade happens. Otherwise the
/// upgraded axum socket is adapted to tungstenite message types, wrapped in an
/// `rpc::Connection`, and passed to `Server::handle_connection` together with the peer
/// address and the `User` extension. That `User` is presumably inserted by the auth
/// middleware on this route; verify. Errors from `handle_connection` are logged, not
/// returned.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject incompatible clients before upgrading the connection.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Provides `log_err` for the connection result below.
        use util::ResultExt;
        // Incoming axum messages become tungstenite messages, stream errors are converted
        // with `err_into`, and outgoing tungstenite messages are converted back to axum.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1813
1814pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1815    let metrics = server.store().await.metrics();
1816    METRIC_CONNECTIONS.set(metrics.connections as _);
1817    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1818
1819    let encoder = prometheus::TextEncoder::new();
1820    let metric_families = prometheus::gather();
1821    match encoder.encode_to_string(&metric_families) {
1822        Ok(string) => (StatusCode::OK, string).into_response(),
1823        Err(error) => (
1824            StatusCode::INTERNAL_SERVER_ERROR,
1825            format!("failed to encode metrics {:?}", error),
1826        )
1827            .into_response(),
1828    }
1829}
1830
1831fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1832    match message {
1833        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1834        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1835        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1836        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1837        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1838            code: frame.code.into(),
1839            reason: frame.reason,
1840        })),
1841    }
1842}
1843
1844fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1845    match message {
1846        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1847        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1848        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1849        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1850        AxumMessage::Close(frame) => {
1851            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1852                code: frame.code.into(),
1853                reason: frame.reason,
1854            }))
1855        }
1856    }
1857}
1858
/// Extension for `Result`s whose error should be reported and then discarded.
pub trait ResultExt {
    /// The success type of the extended `Result`.
    type Ok;

    /// Returns the success value. On error, reports the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1864
1865impl<T, E> ResultExt for Result<T, E>
1866where
1867    E: std::fmt::Debug,
1868{
1869    type Ok = T;
1870
1871    fn trace_err(self) -> Option<T> {
1872        match self {
1873            Ok(value) => Some(value),
1874            Err(error) => {
1875                tracing::error!("{:?}", error);
1876                None
1877            }
1878        }
1879    }
1880}