rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    epoch: parking_lot::Mutex<ServerId>,
 142    peer: Arc<Peer>,
 143    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 144    app_state: Arc<AppState>,
 145    executor: Executor,
 146    handlers: HashMap<TypeId, MessageHandler>,
 147    teardown: watch::Sender<()>,
 148}
 149
 150pub(crate) struct ConnectionPoolGuard<'a> {
 151    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 152    _not_send: PhantomData<Rc<()>>,
 153}
 154
 155#[derive(Serialize)]
 156pub struct ServerSnapshot<'a> {
 157    peer: &'a Peer,
 158    #[serde(serialize_with = "serialize_deref")]
 159    connection_pool: ConnectionPoolGuard<'a>,
 160}
 161
 162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 163where
 164    S: Serializer,
 165    T: Deref<Target = U>,
 166    U: Serialize,
 167{
 168    Serialize::serialize(value.deref(), serializer)
 169}
 170
 171impl Server {
 172    pub fn new(epoch: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 173        let mut server = Self {
 174            epoch: parking_lot::Mutex::new(epoch),
 175            peer: Peer::new(epoch.0 as u32),
 176            app_state,
 177            executor,
 178            connection_pool: Default::default(),
 179            handlers: Default::default(),
 180            teardown: watch::channel(()).0,
 181        };
 182
 183        server
 184            .add_request_handler(ping)
 185            .add_request_handler(create_room)
 186            .add_request_handler(join_room)
 187            .add_message_handler(leave_room)
 188            .add_request_handler(call)
 189            .add_request_handler(cancel_call)
 190            .add_message_handler(decline_call)
 191            .add_request_handler(update_participant_location)
 192            .add_request_handler(share_project)
 193            .add_message_handler(unshare_project)
 194            .add_request_handler(join_project)
 195            .add_message_handler(leave_project)
 196            .add_request_handler(update_project)
 197            .add_request_handler(update_worktree)
 198            .add_message_handler(start_language_server)
 199            .add_message_handler(update_language_server)
 200            .add_message_handler(update_diagnostic_summary)
 201            .add_request_handler(forward_project_request::<proto::GetHover>)
 202            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 203            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 204            .add_request_handler(forward_project_request::<proto::GetReferences>)
 205            .add_request_handler(forward_project_request::<proto::SearchProject>)
 206            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 207            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 209            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 210            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 211            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 213            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 214            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 215            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 216            .add_request_handler(forward_project_request::<proto::PerformRename>)
 217            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 218            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 219            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 221            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 222            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 223            .add_message_handler(create_buffer_for_peer)
 224            .add_request_handler(update_buffer)
 225            .add_message_handler(update_buffer_file)
 226            .add_message_handler(buffer_reloaded)
 227            .add_message_handler(buffer_saved)
 228            .add_request_handler(save_buffer)
 229            .add_request_handler(get_users)
 230            .add_request_handler(fuzzy_search_users)
 231            .add_request_handler(request_contact)
 232            .add_request_handler(remove_contact)
 233            .add_request_handler(respond_to_contact_request)
 234            .add_request_handler(follow)
 235            .add_message_handler(unfollow)
 236            .add_message_handler(update_followers)
 237            .add_message_handler(update_diff_base)
 238            .add_request_handler(get_private_user_info);
 239
 240        Arc::new(server)
 241    }
 242
 243    pub async fn start(&self) -> Result<()> {
 244        let epoch = *self.epoch.lock();
 245        let app_state = self.app_state.clone();
 246        let peer = self.peer.clone();
 247        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 248        let pool = self.connection_pool.clone();
 249        let live_kit_client = self.app_state.live_kit_client.clone();
 250
 251        let span = info_span!("start server");
 252        let span_enter = span.enter();
 253
 254        tracing::info!("begin deleting stale projects");
 255        app_state
 256            .db
 257            .delete_stale_projects(epoch, &app_state.config.zed_environment)
 258            .await?;
 259        tracing::info!("finish deleting stale projects");
 260
 261        drop(span_enter);
 262        self.executor.spawn_detached(
 263            async move {
 264                tracing::info!("waiting for cleanup timeout");
 265                timeout.await;
 266                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 267                if let Some(room_ids) = app_state
 268                    .db
 269                    .stale_room_ids(epoch, &app_state.config.zed_environment)
 270                    .await
 271                    .trace_err()
 272                {
 273                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 274                    for room_id in room_ids {
 275                        let mut contacts_to_update = HashSet::default();
 276                        let mut canceled_calls_to_user_ids = Vec::new();
 277                        let mut live_kit_room = String::new();
 278                        let mut delete_live_kit_room = false;
 279
 280                        if let Ok(mut refreshed_room) =
 281                            app_state.db.refresh_room(room_id, epoch).await
 282                        {
 283                            tracing::info!(
 284                                room_id = room_id.0,
 285                                new_participant_count = refreshed_room.room.participants.len(),
 286                                "refreshed room"
 287                            );
 288                            room_updated(&refreshed_room.room, &peer);
 289                            contacts_to_update
 290                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 291                            contacts_to_update
 292                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 293                            canceled_calls_to_user_ids =
 294                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 295                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 296                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 297                        }
 298
 299                        {
 300                            let pool = pool.lock();
 301                            for canceled_user_id in canceled_calls_to_user_ids {
 302                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 303                                    peer.send(
 304                                        connection_id,
 305                                        proto::CallCanceled {
 306                                            room_id: room_id.to_proto(),
 307                                        },
 308                                    )
 309                                    .trace_err();
 310                                }
 311                            }
 312                        }
 313
 314                        for user_id in contacts_to_update {
 315                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 316                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 317                            if let Some((busy, contacts)) = busy.zip(contacts) {
 318                                let pool = pool.lock();
 319                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 320                                for contact in contacts {
 321                                    if let db::Contact::Accepted {
 322                                        user_id: contact_user_id,
 323                                        ..
 324                                    } = contact
 325                                    {
 326                                        for contact_conn_id in
 327                                            pool.user_connection_ids(contact_user_id)
 328                                        {
 329                                            peer.send(
 330                                                contact_conn_id,
 331                                                proto::UpdateContacts {
 332                                                    contacts: vec![updated_contact.clone()],
 333                                                    remove_contacts: Default::default(),
 334                                                    incoming_requests: Default::default(),
 335                                                    remove_incoming_requests: Default::default(),
 336                                                    outgoing_requests: Default::default(),
 337                                                    remove_outgoing_requests: Default::default(),
 338                                                },
 339                                            )
 340                                            .trace_err();
 341                                        }
 342                                    }
 343                                }
 344                            }
 345                        }
 346
 347                        if let Some(live_kit) = live_kit_client.as_ref() {
 348                            if delete_live_kit_room {
 349                                live_kit.delete_room(live_kit_room).await.trace_err();
 350                            }
 351                        }
 352                    }
 353                }
 354
 355                app_state
 356                    .db
 357                    .delete_stale_servers(epoch, &app_state.config.zed_environment)
 358                    .await
 359                    .trace_err();
 360            }
 361            .instrument(span),
 362        );
 363        Ok(())
 364    }
 365
 366    pub fn teardown(&self) {
 367        self.peer.teardown();
 368        self.connection_pool.lock().reset();
 369        let _ = self.teardown.send(());
 370    }
 371
 372    #[cfg(test)]
 373    pub fn reset(&self, epoch: ServerId) {
 374        self.teardown();
 375        *self.epoch.lock() = epoch;
 376        self.peer.reset(epoch.0 as u32);
 377    }
 378
 379    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 380    where
 381        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 382        Fut: 'static + Send + Future<Output = Result<()>>,
 383        M: EnvelopedMessage,
 384    {
 385        let prev_handler = self.handlers.insert(
 386            TypeId::of::<M>(),
 387            Box::new(move |envelope, session| {
 388                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 389                let span = info_span!(
 390                    "handle message",
 391                    payload_type = envelope.payload_type_name()
 392                );
 393                span.in_scope(|| {
 394                    tracing::info!(
 395                        payload_type = envelope.payload_type_name(),
 396                        "message received"
 397                    );
 398                });
 399                let future = (handler)(*envelope, session);
 400                async move {
 401                    if let Err(error) = future.await {
 402                        tracing::error!(%error, "error handling message");
 403                    }
 404                }
 405                .instrument(span)
 406                .boxed()
 407            }),
 408        );
 409        if prev_handler.is_some() {
 410            panic!("registered a handler for the same message twice");
 411        }
 412        self
 413    }
 414
 415    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 416    where
 417        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 418        Fut: 'static + Send + Future<Output = Result<()>>,
 419        M: EnvelopedMessage,
 420    {
 421        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 422        self
 423    }
 424
 425    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 426    where
 427        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 428        Fut: Send + Future<Output = Result<()>>,
 429        M: RequestMessage,
 430    {
 431        let handler = Arc::new(handler);
 432        self.add_handler(move |envelope, session| {
 433            let receipt = envelope.receipt();
 434            let handler = handler.clone();
 435            async move {
 436                let peer = session.peer.clone();
 437                let responded = Arc::new(AtomicBool::default());
 438                let response = Response {
 439                    peer: peer.clone(),
 440                    responded: responded.clone(),
 441                    receipt,
 442                };
 443                match (handler)(envelope.payload, response, session).await {
 444                    Ok(()) => {
 445                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 446                            Ok(())
 447                        } else {
 448                            Err(anyhow!("handler did not send a response"))?
 449                        }
 450                    }
 451                    Err(error) => {
 452                        peer.respond_with_error(
 453                            receipt,
 454                            proto::Error {
 455                                message: error.to_string(),
 456                            },
 457                        )?;
 458                        Err(error)
 459                    }
 460                }
 461            }
 462        })
 463    }
 464
 465    pub fn handle_connection(
 466        self: &Arc<Self>,
 467        connection: Connection,
 468        address: String,
 469        user: User,
 470        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 471        executor: Executor,
 472    ) -> impl Future<Output = Result<()>> {
 473        let this = self.clone();
 474        let user_id = user.id;
 475        let login = user.github_login;
 476        let span = info_span!("handle connection", %user_id, %login, %address);
 477        let mut teardown = self.teardown.subscribe();
 478        async move {
 479            let (connection_id, handle_io, mut incoming_rx) = this
 480                .peer
 481                .add_connection(connection, {
 482                    let executor = executor.clone();
 483                    move |duration| executor.sleep(duration)
 484                });
 485
 486            tracing::info!(%user_id, %login, ?connection_id, %address, "connection opened");
 487            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 488            tracing::info!(%user_id, %login, ?connection_id, %address, "sent hello message");
 489
 490            if let Some(send_connection_id) = send_connection_id.take() {
 491                let _ = send_connection_id.send(connection_id);
 492            }
 493
 494            if !user.connected_once {
 495                this.peer.send(connection_id, proto::ShowContacts {})?;
 496                this.app_state.db.set_user_connected_once(user_id, true).await?;
 497            }
 498
 499            let (contacts, invite_code) = future::try_join(
 500                this.app_state.db.get_contacts(user_id),
 501                this.app_state.db.get_invite_code_for_user(user_id)
 502            ).await?;
 503
 504            {
 505                let mut pool = this.connection_pool.lock();
 506                pool.add_connection(connection_id, user_id, user.admin);
 507                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 508
 509                if let Some((code, count)) = invite_code {
 510                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 511                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 512                        count: count as u32,
 513                    })?;
 514                }
 515            }
 516
 517            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 518                this.peer.send(connection_id, incoming_call)?;
 519            }
 520
 521            let session = Session {
 522                user_id,
 523                connection_id,
 524                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 525                peer: this.peer.clone(),
 526                connection_pool: this.connection_pool.clone(),
 527                live_kit_client: this.app_state.live_kit_client.clone()
 528            };
 529            update_user_contacts(user_id, &session).await?;
 530
 531            let handle_io = handle_io.fuse();
 532            futures::pin_mut!(handle_io);
 533
 534            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 535            // This prevents deadlocks when e.g., client A performs a request to client B and
 536            // client B performs a request to client A. If both clients stop processing further
 537            // messages until their respective request completes, they won't have a chance to
 538            // respond to the other client's request and cause a deadlock.
 539            //
 540            // This arrangement ensures we will attempt to process earlier messages first, but fall
 541            // back to processing messages arrived later in the spirit of making progress.
 542            let mut foreground_message_handlers = FuturesUnordered::new();
 543            loop {
 544                let next_message = incoming_rx.next().fuse();
 545                futures::pin_mut!(next_message);
 546                futures::select_biased! {
 547                    _ = teardown.changed().fuse() => return Ok(()),
 548                    result = handle_io => {
 549                        if let Err(error) = result {
 550                            tracing::error!(?error, %user_id, %login, ?connection_id, %address, "error handling I/O");
 551                        }
 552                        break;
 553                    }
 554                    _ = foreground_message_handlers.next() => {}
 555                    message = next_message => {
 556                        if let Some(message) = message {
 557                            let type_name = message.payload_type_name();
 558                            let span = tracing::info_span!("receive message", %user_id, %login, ?connection_id, %address, type_name);
 559                            let span_enter = span.enter();
 560                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 561                                let is_background = message.is_background();
 562                                let handle_message = (handler)(message, session.clone());
 563                                drop(span_enter);
 564
 565                                let handle_message = handle_message.instrument(span);
 566                                if is_background {
 567                                    executor.spawn_detached(handle_message);
 568                                } else {
 569                                    foreground_message_handlers.push(handle_message);
 570                                }
 571                            } else {
 572                                tracing::error!(%user_id, %login, ?connection_id, %address, "no message handler");
 573                            }
 574                        } else {
 575                            tracing::info!(%user_id, %login, ?connection_id, %address, "connection closed");
 576                            break;
 577                        }
 578                    }
 579                }
 580            }
 581
 582            drop(foreground_message_handlers);
 583            tracing::info!(%user_id, %login, ?connection_id, %address, "signing out");
 584            if let Err(error) = sign_out(session, teardown, executor).await {
 585                tracing::error!(%user_id, %login, ?connection_id, %address, ?error, "error signing out");
 586            }
 587
 588            Ok(())
 589        }.instrument(span)
 590    }
 591
 592    pub async fn invite_code_redeemed(
 593        self: &Arc<Self>,
 594        inviter_id: UserId,
 595        invitee_id: UserId,
 596    ) -> Result<()> {
 597        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 598            if let Some(code) = &user.invite_code {
 599                let pool = self.connection_pool.lock();
 600                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 601                for connection_id in pool.user_connection_ids(inviter_id) {
 602                    self.peer.send(
 603                        connection_id,
 604                        proto::UpdateContacts {
 605                            contacts: vec![invitee_contact.clone()],
 606                            ..Default::default()
 607                        },
 608                    )?;
 609                    self.peer.send(
 610                        connection_id,
 611                        proto::UpdateInviteInfo {
 612                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 613                            count: user.invite_count as u32,
 614                        },
 615                    )?;
 616                }
 617            }
 618        }
 619        Ok(())
 620    }
 621
 622    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 623        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 624            if let Some(invite_code) = &user.invite_code {
 625                let pool = self.connection_pool.lock();
 626                for connection_id in pool.user_connection_ids(user_id) {
 627                    self.peer.send(
 628                        connection_id,
 629                        proto::UpdateInviteInfo {
 630                            url: format!(
 631                                "{}{}",
 632                                self.app_state.config.invite_link_prefix, invite_code
 633                            ),
 634                            count: user.invite_count as u32,
 635                        },
 636                    )?;
 637                }
 638            }
 639        }
 640        Ok(())
 641    }
 642
 643    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 644        ServerSnapshot {
 645            connection_pool: ConnectionPoolGuard {
 646                guard: self.connection_pool.lock(),
 647                _not_send: PhantomData,
 648            },
 649            peer: &self.peer,
 650        }
 651    }
 652}
 653
 654impl<'a> Deref for ConnectionPoolGuard<'a> {
 655    type Target = ConnectionPool;
 656
 657    fn deref(&self) -> &Self::Target {
 658        &*self.guard
 659    }
 660}
 661
 662impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 663    fn deref_mut(&mut self) -> &mut Self::Target {
 664        &mut *self.guard
 665    }
 666}
 667
 668impl<'a> Drop for ConnectionPoolGuard<'a> {
 669    fn drop(&mut self) {
 670        #[cfg(test)]
 671        self.check_invariants();
 672    }
 673}
 674
 675fn broadcast<F>(
 676    sender_id: ConnectionId,
 677    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 678    mut f: F,
 679) where
 680    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 681{
 682    for receiver_id in receiver_ids {
 683        if receiver_id != sender_id {
 684            f(receiver_id).trace_err();
 685        }
 686    }
 687}
 688
 689lazy_static! {
 690    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 691}
 692
 693pub struct ProtocolVersion(u32);
 694
 695impl Header for ProtocolVersion {
 696    fn name() -> &'static HeaderName {
 697        &ZED_PROTOCOL_VERSION
 698    }
 699
 700    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 701    where
 702        Self: Sized,
 703        I: Iterator<Item = &'i axum::http::HeaderValue>,
 704    {
 705        let version = values
 706            .next()
 707            .ok_or_else(axum::headers::Error::invalid)?
 708            .to_str()
 709            .map_err(|_| axum::headers::Error::invalid())?
 710            .parse()
 711            .map_err(|_| axum::headers::Error::invalid())?;
 712        Ok(Self(version))
 713    }
 714
 715    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 716        values.extend([self.0.to_string().parse().unwrap()]);
 717    }
 718}
 719
 720pub fn routes(server: Arc<Server>) -> Router<Body> {
 721    Router::new()
 722        .route("/rpc", get(handle_websocket_request))
 723        .layer(
 724            ServiceBuilder::new()
 725                .layer(Extension(server.app_state.clone()))
 726                .layer(middleware::from_fn(auth::validate_header)),
 727        )
 728        .route("/metrics", get(handle_metrics))
 729        .layer(Extension(server))
 730}
 731
 732pub async fn handle_websocket_request(
 733    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 734    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 735    Extension(server): Extension<Arc<Server>>,
 736    Extension(user): Extension<User>,
 737    ws: WebSocketUpgrade,
 738) -> axum::response::Response {
 739    if protocol_version != rpc::PROTOCOL_VERSION {
 740        return (
 741            StatusCode::UPGRADE_REQUIRED,
 742            "client must be upgraded".to_string(),
 743        )
 744            .into_response();
 745    }
 746    let socket_address = socket_address.to_string();
 747    ws.on_upgrade(move |socket| {
 748        use util::ResultExt;
 749        let socket = socket
 750            .map_ok(to_tungstenite_message)
 751            .err_into()
 752            .with(|message| async move { Ok(to_axum_message(message)) });
 753        let connection = Connection::new(Box::pin(socket));
 754        async move {
 755            server
 756                .handle_connection(connection, socket_address, user, None, Executor::Production)
 757                .await
 758                .log_err();
 759        }
 760    })
 761}
 762
 763pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 764    let connections = server
 765        .connection_pool
 766        .lock()
 767        .connections()
 768        .filter(|connection| !connection.admin)
 769        .count();
 770
 771    METRIC_CONNECTIONS.set(connections as _);
 772
 773    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 774    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 775
 776    let encoder = prometheus::TextEncoder::new();
 777    let metric_families = prometheus::gather();
 778    let encoded_metrics = encoder
 779        .encode_to_string(&metric_families)
 780        .map_err(|err| anyhow!("{}", err))?;
 781    Ok(encoded_metrics)
 782}
 783
/// Tears down server-side state for a connection that has gone away.
///
/// It acts immediately in three ways:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - reports the lost connection to the database, passing each project it returns to
///   `project_left`.
///
/// It then waits for whichever of these happens first:
/// - `RECONNECT_TIMEOUT` elapses. The session leaves its room. If the user has no other
///   online connection, a pending call for them is declined (with no specific room). Finally
///   the user's contacts are refreshed.
/// - `teardown` changes. The function returns without doing the room/call cleanup above.
///
/// NOTE(review): the delay presumably gives the client a window to reconnect before it is
/// removed from its room — confirm against the reconnect logic.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A database failure here is logged by `trace_err` rather than aborting sign-out.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls branches in the order listed, so the timeout branch wins if both
    // are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline the pending call once no connection of this user remains online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 829
 830async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 831    response.send(proto::Ack {})?;
 832    Ok(())
 833}
 834
 835async fn create_room(
 836    _request: proto::CreateRoom,
 837    response: Response<proto::CreateRoom>,
 838    session: Session,
 839) -> Result<()> {
 840    let live_kit_room = nanoid::nanoid!(30);
 841    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 842        if let Some(_) = live_kit
 843            .create_room(live_kit_room.clone())
 844            .await
 845            .trace_err()
 846        {
 847            if let Some(token) = live_kit
 848                .room_token(&live_kit_room, &session.connection_id.to_string())
 849                .trace_err()
 850            {
 851                Some(proto::LiveKitConnectionInfo {
 852                    server_url: live_kit.url().into(),
 853                    token,
 854                })
 855            } else {
 856                None
 857            }
 858        } else {
 859            None
 860        }
 861    } else {
 862        None
 863    };
 864
 865    {
 866        let room = session
 867            .db()
 868            .await
 869            .create_room(session.user_id, session.connection_id, &live_kit_room)
 870            .await?;
 871
 872        response.send(proto::CreateRoomResponse {
 873            room: Some(room.clone()),
 874            live_kit_connection_info,
 875        })?;
 876    }
 877
 878    update_user_contacts(session.user_id, &session).await?;
 879    Ok(())
 880}
 881
 882async fn join_room(
 883    request: proto::JoinRoom,
 884    response: Response<proto::JoinRoom>,
 885    session: Session,
 886) -> Result<()> {
 887    let room_id = RoomId::from_proto(request.id);
 888    let room = {
 889        let room = session
 890            .db()
 891            .await
 892            .join_room(room_id, session.user_id, session.connection_id)
 893            .await?;
 894        room_updated(&room, &session.peer);
 895        room.clone()
 896    };
 897
 898    for connection_id in session
 899        .connection_pool()
 900        .await
 901        .user_connection_ids(session.user_id)
 902    {
 903        session
 904            .peer
 905            .send(
 906                connection_id,
 907                proto::CallCanceled {
 908                    room_id: room_id.to_proto(),
 909                },
 910            )
 911            .trace_err();
 912    }
 913
 914    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 915        if let Some(token) = live_kit
 916            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 917            .trace_err()
 918        {
 919            Some(proto::LiveKitConnectionInfo {
 920                server_url: live_kit.url().into(),
 921                token,
 922            })
 923        } else {
 924            None
 925        }
 926    } else {
 927        None
 928    };
 929
 930    response.send(proto::JoinRoomResponse {
 931        room: Some(room),
 932        live_kit_connection_info,
 933    })?;
 934
 935    update_user_contacts(session.user_id, &session).await?;
 936    Ok(())
 937}
 938
 939async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 940    leave_room_for_session(&session).await
 941}
 942
/// Rings `request.called_user_id`, inviting them into room `request.room_id`.
///
/// Only contacts of the caller may be called. The handler proceeds as follows:
/// 1. Records the call in the database and broadcasts the updated room.
/// 2. Refreshes the called user's contacts.
/// 3. Sends the resulting incoming-call message as a request to every connection of the
///    called user.
/// 4. Acks the caller as soon as any of those requests succeeds.
///
/// If none succeeds (including when the user has no connections), the call is marked failed
/// in the database, the room is re-broadcast, the called user's contacts are refreshed, and
/// the handler returns a "failed to ring user" error.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // `mem::take` moves the message out of the database result so it outlives this block.
    // NOTE(review): the block scope presumably exists to drop a DB guard before contacts are
    // updated — confirm.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The connection-pool guard is a temporary of this statement, so it is released before
    // any of the ring requests are awaited below.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response acks the caller; failures are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failure and re-broadcast the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1010
1011async fn cancel_call(
1012    request: proto::CancelCall,
1013    response: Response<proto::CancelCall>,
1014    session: Session,
1015) -> Result<()> {
1016    let called_user_id = UserId::from_proto(request.called_user_id);
1017    let room_id = RoomId::from_proto(request.room_id);
1018    {
1019        let room = session
1020            .db()
1021            .await
1022            .cancel_call(Some(room_id), session.connection_id, called_user_id)
1023            .await?;
1024        room_updated(&room, &session.peer);
1025    }
1026
1027    for connection_id in session
1028        .connection_pool()
1029        .await
1030        .user_connection_ids(called_user_id)
1031    {
1032        session
1033            .peer
1034            .send(
1035                connection_id,
1036                proto::CallCanceled {
1037                    room_id: room_id.to_proto(),
1038                },
1039            )
1040            .trace_err();
1041    }
1042    response.send(proto::Ack {})?;
1043
1044    update_user_contacts(called_user_id, &session).await?;
1045    Ok(())
1046}
1047
1048async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1049    let room_id = RoomId::from_proto(message.room_id);
1050    {
1051        let room = session
1052            .db()
1053            .await
1054            .decline_call(Some(room_id), session.user_id)
1055            .await?;
1056        room_updated(&room, &session.peer);
1057    }
1058
1059    for connection_id in session
1060        .connection_pool()
1061        .await
1062        .user_connection_ids(session.user_id)
1063    {
1064        session
1065            .peer
1066            .send(
1067                connection_id,
1068                proto::CallCanceled {
1069                    room_id: room_id.to_proto(),
1070                },
1071            )
1072            .trace_err();
1073    }
1074    update_user_contacts(session.user_id, &session).await?;
1075    Ok(())
1076}
1077
1078async fn update_participant_location(
1079    request: proto::UpdateParticipantLocation,
1080    response: Response<proto::UpdateParticipantLocation>,
1081    session: Session,
1082) -> Result<()> {
1083    let room_id = RoomId::from_proto(request.room_id);
1084    let location = request
1085        .location
1086        .ok_or_else(|| anyhow!("invalid location"))?;
1087    let room = session
1088        .db()
1089        .await
1090        .update_room_participant_location(room_id, session.connection_id, location)
1091        .await?;
1092    room_updated(&room, &session.peer);
1093    response.send(proto::Ack {})?;
1094    Ok(())
1095}
1096
1097async fn share_project(
1098    request: proto::ShareProject,
1099    response: Response<proto::ShareProject>,
1100    session: Session,
1101) -> Result<()> {
1102    let (project_id, room) = &*session
1103        .db()
1104        .await
1105        .share_project(
1106            RoomId::from_proto(request.room_id),
1107            session.connection_id,
1108            &request.worktrees,
1109        )
1110        .await?;
1111    response.send(proto::ShareProjectResponse {
1112        project_id: project_id.to_proto(),
1113    })?;
1114    room_updated(&room, &session.peer);
1115
1116    Ok(())
1117}
1118
1119async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1120    let project_id = ProjectId::from_proto(message.project_id);
1121
1122    let (room, guest_connection_ids) = &*session
1123        .db()
1124        .await
1125        .unshare_project(project_id, session.connection_id)
1126        .await?;
1127
1128    broadcast(
1129        session.connection_id,
1130        guest_connection_ids.iter().copied(),
1131        |conn_id| session.peer.send(conn_id, message.clone()),
1132    );
1133    room_updated(&room, &session.peer);
1134
1135    Ok(())
1136}
1137
1138async fn join_project(
1139    request: proto::JoinProject,
1140    response: Response<proto::JoinProject>,
1141    session: Session,
1142) -> Result<()> {
1143    let project_id = ProjectId::from_proto(request.project_id);
1144    let guest_user_id = session.user_id;
1145
1146    tracing::info!(%project_id, "join project");
1147
1148    let (project, replica_id) = &mut *session
1149        .db()
1150        .await
1151        .join_project(project_id, session.connection_id)
1152        .await?;
1153
1154    let collaborators = project
1155        .collaborators
1156        .iter()
1157        .map(|collaborator| {
1158            let peer_id = proto::PeerId {
1159                epoch: collaborator.connection_server_id.0 as u32,
1160                id: collaborator.connection_id as u32,
1161            };
1162            proto::Collaborator {
1163                peer_id: Some(peer_id),
1164                replica_id: collaborator.replica_id.0 as u32,
1165                user_id: collaborator.user_id.to_proto(),
1166            }
1167        })
1168        .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1169        .collect::<Vec<_>>();
1170    let worktrees = project
1171        .worktrees
1172        .iter()
1173        .map(|(id, worktree)| proto::WorktreeMetadata {
1174            id: *id,
1175            root_name: worktree.root_name.clone(),
1176            visible: worktree.visible,
1177            abs_path: worktree.abs_path.clone(),
1178        })
1179        .collect::<Vec<_>>();
1180
1181    for collaborator in &collaborators {
1182        session
1183            .peer
1184            .send(
1185                collaborator.peer_id.unwrap().into(),
1186                proto::AddProjectCollaborator {
1187                    project_id: project_id.to_proto(),
1188                    collaborator: Some(proto::Collaborator {
1189                        peer_id: Some(session.connection_id.into()),
1190                        replica_id: replica_id.0 as u32,
1191                        user_id: guest_user_id.to_proto(),
1192                    }),
1193                },
1194            )
1195            .trace_err();
1196    }
1197
1198    // First, we send the metadata associated with each worktree.
1199    response.send(proto::JoinProjectResponse {
1200        worktrees: worktrees.clone(),
1201        replica_id: replica_id.0 as u32,
1202        collaborators: collaborators.clone(),
1203        language_servers: project.language_servers.clone(),
1204    })?;
1205
1206    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1207        #[cfg(any(test, feature = "test-support"))]
1208        const MAX_CHUNK_SIZE: usize = 2;
1209        #[cfg(not(any(test, feature = "test-support")))]
1210        const MAX_CHUNK_SIZE: usize = 256;
1211
1212        // Stream this worktree's entries.
1213        let message = proto::UpdateWorktree {
1214            project_id: project_id.to_proto(),
1215            worktree_id,
1216            abs_path: worktree.abs_path.clone(),
1217            root_name: worktree.root_name,
1218            updated_entries: worktree.entries,
1219            removed_entries: Default::default(),
1220            scan_id: worktree.scan_id,
1221            is_last_update: worktree.is_complete,
1222        };
1223        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1224            session.peer.send(session.connection_id, update.clone())?;
1225        }
1226
1227        // Stream this worktree's diagnostics.
1228        for summary in worktree.diagnostic_summaries {
1229            session.peer.send(
1230                session.connection_id,
1231                proto::UpdateDiagnosticSummary {
1232                    project_id: project_id.to_proto(),
1233                    worktree_id: worktree.id,
1234                    summary: Some(summary),
1235                },
1236            )?;
1237        }
1238    }
1239
1240    for language_server in &project.language_servers {
1241        session.peer.send(
1242            session.connection_id,
1243            proto::UpdateLanguageServer {
1244                project_id: project_id.to_proto(),
1245                language_server_id: language_server.id,
1246                variant: Some(
1247                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1248                        proto::LspDiskBasedDiagnosticsUpdated {},
1249                    ),
1250                ),
1251            },
1252        )?;
1253    }
1254
1255    Ok(())
1256}
1257
1258async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1259    let sender_id = session.connection_id;
1260    let project_id = ProjectId::from_proto(request.project_id);
1261
1262    let project = session
1263        .db()
1264        .await
1265        .leave_project(project_id, sender_id)
1266        .await?;
1267    tracing::info!(
1268        %project_id,
1269        host_user_id = %project.host_user_id,
1270        host_connection_id = %project.host_connection_id,
1271        "leave project"
1272    );
1273    project_left(&project, &session);
1274
1275    Ok(())
1276}
1277
1278async fn update_project(
1279    request: proto::UpdateProject,
1280    response: Response<proto::UpdateProject>,
1281    session: Session,
1282) -> Result<()> {
1283    let project_id = ProjectId::from_proto(request.project_id);
1284    let (room, guest_connection_ids) = &*session
1285        .db()
1286        .await
1287        .update_project(project_id, session.connection_id, &request.worktrees)
1288        .await?;
1289    broadcast(
1290        session.connection_id,
1291        guest_connection_ids.iter().copied(),
1292        |connection_id| {
1293            session
1294                .peer
1295                .forward_send(session.connection_id, connection_id, request.clone())
1296        },
1297    );
1298    room_updated(&room, &session.peer);
1299    response.send(proto::Ack {})?;
1300
1301    Ok(())
1302}
1303
1304async fn update_worktree(
1305    request: proto::UpdateWorktree,
1306    response: Response<proto::UpdateWorktree>,
1307    session: Session,
1308) -> Result<()> {
1309    let guest_connection_ids = session
1310        .db()
1311        .await
1312        .update_worktree(&request, session.connection_id)
1313        .await?;
1314
1315    broadcast(
1316        session.connection_id,
1317        guest_connection_ids.iter().copied(),
1318        |connection_id| {
1319            session
1320                .peer
1321                .forward_send(session.connection_id, connection_id, request.clone())
1322        },
1323    );
1324    response.send(proto::Ack {})?;
1325    Ok(())
1326}
1327
1328async fn update_diagnostic_summary(
1329    message: proto::UpdateDiagnosticSummary,
1330    session: Session,
1331) -> Result<()> {
1332    let guest_connection_ids = session
1333        .db()
1334        .await
1335        .update_diagnostic_summary(&message, session.connection_id)
1336        .await?;
1337
1338    broadcast(
1339        session.connection_id,
1340        guest_connection_ids.iter().copied(),
1341        |connection_id| {
1342            session
1343                .peer
1344                .forward_send(session.connection_id, connection_id, message.clone())
1345        },
1346    );
1347
1348    Ok(())
1349}
1350
1351async fn start_language_server(
1352    request: proto::StartLanguageServer,
1353    session: Session,
1354) -> Result<()> {
1355    let guest_connection_ids = session
1356        .db()
1357        .await
1358        .start_language_server(&request, session.connection_id)
1359        .await?;
1360
1361    broadcast(
1362        session.connection_id,
1363        guest_connection_ids.iter().copied(),
1364        |connection_id| {
1365            session
1366                .peer
1367                .forward_send(session.connection_id, connection_id, request.clone())
1368        },
1369    );
1370    Ok(())
1371}
1372
1373async fn update_language_server(
1374    request: proto::UpdateLanguageServer,
1375    session: Session,
1376) -> Result<()> {
1377    let project_id = ProjectId::from_proto(request.project_id);
1378    let project_connection_ids = session
1379        .db()
1380        .await
1381        .project_connection_ids(project_id, session.connection_id)
1382        .await?;
1383    broadcast(
1384        session.connection_id,
1385        project_connection_ids.iter().copied(),
1386        |connection_id| {
1387            session
1388                .peer
1389                .forward_send(session.connection_id, connection_id, request.clone())
1390        },
1391    );
1392    Ok(())
1393}
1394
1395async fn forward_project_request<T>(
1396    request: T,
1397    response: Response<T>,
1398    session: Session,
1399) -> Result<()>
1400where
1401    T: EntityMessage + RequestMessage,
1402{
1403    let project_id = ProjectId::from_proto(request.remote_entity_id());
1404    let host_connection_id = {
1405        let collaborators = session
1406            .db()
1407            .await
1408            .project_collaborators(project_id, session.connection_id)
1409            .await?;
1410        let host = collaborators
1411            .iter()
1412            .find(|collaborator| collaborator.is_host)
1413            .ok_or_else(|| anyhow!("host not found"))?;
1414        ConnectionId {
1415            epoch: host.connection_server_id.0 as u32,
1416            id: host.connection_id as u32,
1417        }
1418    };
1419
1420    let payload = session
1421        .peer
1422        .forward_request(session.connection_id, host_connection_id, request)
1423        .await?;
1424
1425    response.send(payload)?;
1426    Ok(())
1427}
1428
1429async fn save_buffer(
1430    request: proto::SaveBuffer,
1431    response: Response<proto::SaveBuffer>,
1432    session: Session,
1433) -> Result<()> {
1434    let project_id = ProjectId::from_proto(request.project_id);
1435    let host_connection_id = {
1436        let collaborators = session
1437            .db()
1438            .await
1439            .project_collaborators(project_id, session.connection_id)
1440            .await?;
1441        let host = collaborators
1442            .iter()
1443            .find(|collaborator| collaborator.is_host)
1444            .ok_or_else(|| anyhow!("host not found"))?;
1445        ConnectionId {
1446            epoch: host.connection_server_id.0 as u32,
1447            id: host.connection_id as u32,
1448        }
1449    };
1450    let response_payload = session
1451        .peer
1452        .forward_request(session.connection_id, host_connection_id, request.clone())
1453        .await?;
1454
1455    let mut collaborators = session
1456        .db()
1457        .await
1458        .project_collaborators(project_id, session.connection_id)
1459        .await?;
1460    collaborators.retain(|collaborator| {
1461        let collaborator_connection = ConnectionId {
1462            epoch: collaborator.connection_server_id.0 as u32,
1463            id: collaborator.connection_id as u32,
1464        };
1465        collaborator_connection != session.connection_id
1466    });
1467    let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
1468        epoch: collaborator.connection_server_id.0 as u32,
1469        id: collaborator.connection_id as u32,
1470    });
1471    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1472        session
1473            .peer
1474            .forward_send(host_connection_id, conn_id, response_payload.clone())
1475    });
1476    response.send(response_payload)?;
1477    Ok(())
1478}
1479
1480async fn create_buffer_for_peer(
1481    request: proto::CreateBufferForPeer,
1482    session: Session,
1483) -> Result<()> {
1484    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1485    session
1486        .peer
1487        .forward_send(session.connection_id, peer_id.into(), request)?;
1488    Ok(())
1489}
1490
1491async fn update_buffer(
1492    request: proto::UpdateBuffer,
1493    response: Response<proto::UpdateBuffer>,
1494    session: Session,
1495) -> Result<()> {
1496    let project_id = ProjectId::from_proto(request.project_id);
1497    let project_connection_ids = session
1498        .db()
1499        .await
1500        .project_connection_ids(project_id, session.connection_id)
1501        .await?;
1502
1503    dbg!(session.connection_id, &*project_connection_ids);
1504    broadcast(
1505        session.connection_id,
1506        project_connection_ids.iter().copied(),
1507        |connection_id| {
1508            session
1509                .peer
1510                .forward_send(session.connection_id, connection_id, request.clone())
1511        },
1512    );
1513    response.send(proto::Ack {})?;
1514    Ok(())
1515}
1516
1517async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1518    let project_id = ProjectId::from_proto(request.project_id);
1519    let project_connection_ids = session
1520        .db()
1521        .await
1522        .project_connection_ids(project_id, session.connection_id)
1523        .await?;
1524
1525    broadcast(
1526        session.connection_id,
1527        project_connection_ids.iter().copied(),
1528        |connection_id| {
1529            session
1530                .peer
1531                .forward_send(session.connection_id, connection_id, request.clone())
1532        },
1533    );
1534    Ok(())
1535}
1536
1537async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1538    let project_id = ProjectId::from_proto(request.project_id);
1539    let project_connection_ids = session
1540        .db()
1541        .await
1542        .project_connection_ids(project_id, session.connection_id)
1543        .await?;
1544    broadcast(
1545        session.connection_id,
1546        project_connection_ids.iter().copied(),
1547        |connection_id| {
1548            session
1549                .peer
1550                .forward_send(session.connection_id, connection_id, request.clone())
1551        },
1552    );
1553    Ok(())
1554}
1555
1556async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1557    let project_id = ProjectId::from_proto(request.project_id);
1558    let project_connection_ids = session
1559        .db()
1560        .await
1561        .project_connection_ids(project_id, session.connection_id)
1562        .await?;
1563    broadcast(
1564        session.connection_id,
1565        project_connection_ids.iter().copied(),
1566        |connection_id| {
1567            session
1568                .peer
1569                .forward_send(session.connection_id, connection_id, request.clone())
1570        },
1571    );
1572    Ok(())
1573}
1574
1575async fn follow(
1576    request: proto::Follow,
1577    response: Response<proto::Follow>,
1578    session: Session,
1579) -> Result<()> {
1580    let project_id = ProjectId::from_proto(request.project_id);
1581    let leader_id = request
1582        .leader_id
1583        .ok_or_else(|| anyhow!("invalid leader id"))?
1584        .into();
1585    let follower_id = session.connection_id;
1586    {
1587        let project_connection_ids = session
1588            .db()
1589            .await
1590            .project_connection_ids(project_id, session.connection_id)
1591            .await?;
1592
1593        if !project_connection_ids.contains(&leader_id) {
1594            Err(anyhow!("no such peer"))?;
1595        }
1596    }
1597
1598    let mut response_payload = session
1599        .peer
1600        .forward_request(session.connection_id, leader_id, request)
1601        .await?;
1602    response_payload
1603        .views
1604        .retain(|view| view.leader_id != Some(follower_id.into()));
1605    response.send(response_payload)?;
1606    Ok(())
1607}
1608
1609async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1610    let project_id = ProjectId::from_proto(request.project_id);
1611    let leader_id = request
1612        .leader_id
1613        .ok_or_else(|| anyhow!("invalid leader id"))?
1614        .into();
1615    let project_connection_ids = session
1616        .db()
1617        .await
1618        .project_connection_ids(project_id, session.connection_id)
1619        .await?;
1620    if !project_connection_ids.contains(&leader_id) {
1621        Err(anyhow!("no such peer"))?;
1622    }
1623    session
1624        .peer
1625        .forward_send(session.connection_id, leader_id, request)?;
1626    Ok(())
1627}
1628
1629async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1630    let project_id = ProjectId::from_proto(request.project_id);
1631    let project_connection_ids = session
1632        .db
1633        .lock()
1634        .await
1635        .project_connection_ids(project_id, session.connection_id)
1636        .await?;
1637
1638    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1639        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1640        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1641        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1642    });
1643    for follower_peer_id in request.follower_ids.iter().copied() {
1644        let follower_connection_id = follower_peer_id.into();
1645        if project_connection_ids.contains(&follower_connection_id)
1646            && Some(follower_peer_id) != leader_id
1647        {
1648            session.peer.forward_send(
1649                session.connection_id,
1650                follower_connection_id,
1651                request.clone(),
1652            )?;
1653        }
1654    }
1655    Ok(())
1656}
1657
1658async fn get_users(
1659    request: proto::GetUsers,
1660    response: Response<proto::GetUsers>,
1661    session: Session,
1662) -> Result<()> {
1663    let user_ids = request
1664        .user_ids
1665        .into_iter()
1666        .map(UserId::from_proto)
1667        .collect();
1668    let users = session
1669        .db()
1670        .await
1671        .get_users_by_ids(user_ids)
1672        .await?
1673        .into_iter()
1674        .map(|user| proto::User {
1675            id: user.id.to_proto(),
1676            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1677            github_login: user.github_login,
1678        })
1679        .collect();
1680    response.send(proto::UsersResponse { users })?;
1681    Ok(())
1682}
1683
1684async fn fuzzy_search_users(
1685    request: proto::FuzzySearchUsers,
1686    response: Response<proto::FuzzySearchUsers>,
1687    session: Session,
1688) -> Result<()> {
1689    let query = request.query;
1690    let users = match query.len() {
1691        0 => vec![],
1692        1 | 2 => session
1693            .db()
1694            .await
1695            .get_user_by_github_account(&query, None)
1696            .await?
1697            .into_iter()
1698            .collect(),
1699        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1700    };
1701    let users = users
1702        .into_iter()
1703        .filter(|user| user.id != session.user_id)
1704        .map(|user| proto::User {
1705            id: user.id.to_proto(),
1706            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1707            github_login: user.github_login,
1708        })
1709        .collect();
1710    response.send(proto::UsersResponse { users })?;
1711    Ok(())
1712}
1713
1714async fn request_contact(
1715    request: proto::RequestContact,
1716    response: Response<proto::RequestContact>,
1717    session: Session,
1718) -> Result<()> {
1719    let requester_id = session.user_id;
1720    let responder_id = UserId::from_proto(request.responder_id);
1721    if requester_id == responder_id {
1722        return Err(anyhow!("cannot add yourself as a contact"))?;
1723    }
1724
1725    session
1726        .db()
1727        .await
1728        .send_contact_request(requester_id, responder_id)
1729        .await?;
1730
1731    // Update outgoing contact requests of requester
1732    let mut update = proto::UpdateContacts::default();
1733    update.outgoing_requests.push(responder_id.to_proto());
1734    for connection_id in session
1735        .connection_pool()
1736        .await
1737        .user_connection_ids(requester_id)
1738    {
1739        session.peer.send(connection_id, update.clone())?;
1740    }
1741
1742    // Update incoming contact requests of responder
1743    let mut update = proto::UpdateContacts::default();
1744    update
1745        .incoming_requests
1746        .push(proto::IncomingContactRequest {
1747            requester_id: requester_id.to_proto(),
1748            should_notify: true,
1749        });
1750    for connection_id in session
1751        .connection_pool()
1752        .await
1753        .user_connection_ids(responder_id)
1754    {
1755        session.peer.send(connection_id, update.clone())?;
1756    }
1757
1758    response.send(proto::Ack {})?;
1759    Ok(())
1760}
1761
1762async fn respond_to_contact_request(
1763    request: proto::RespondToContactRequest,
1764    response: Response<proto::RespondToContactRequest>,
1765    session: Session,
1766) -> Result<()> {
1767    let responder_id = session.user_id;
1768    let requester_id = UserId::from_proto(request.requester_id);
1769    let db = session.db().await;
1770    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1771        db.dismiss_contact_notification(responder_id, requester_id)
1772            .await?;
1773    } else {
1774        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1775
1776        db.respond_to_contact_request(responder_id, requester_id, accept)
1777            .await?;
1778        let requester_busy = db.is_user_busy(requester_id).await?;
1779        let responder_busy = db.is_user_busy(responder_id).await?;
1780
1781        let pool = session.connection_pool().await;
1782        // Update responder with new contact
1783        let mut update = proto::UpdateContacts::default();
1784        if accept {
1785            update
1786                .contacts
1787                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1788        }
1789        update
1790            .remove_incoming_requests
1791            .push(requester_id.to_proto());
1792        for connection_id in pool.user_connection_ids(responder_id) {
1793            session.peer.send(connection_id, update.clone())?;
1794        }
1795
1796        // Update requester with new contact
1797        let mut update = proto::UpdateContacts::default();
1798        if accept {
1799            update
1800                .contacts
1801                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1802        }
1803        update
1804            .remove_outgoing_requests
1805            .push(responder_id.to_proto());
1806        for connection_id in pool.user_connection_ids(requester_id) {
1807            session.peer.send(connection_id, update.clone())?;
1808        }
1809    }
1810
1811    response.send(proto::Ack {})?;
1812    Ok(())
1813}
1814
1815async fn remove_contact(
1816    request: proto::RemoveContact,
1817    response: Response<proto::RemoveContact>,
1818    session: Session,
1819) -> Result<()> {
1820    let requester_id = session.user_id;
1821    let responder_id = UserId::from_proto(request.user_id);
1822    let db = session.db().await;
1823    db.remove_contact(requester_id, responder_id).await?;
1824
1825    let pool = session.connection_pool().await;
1826    // Update outgoing contact requests of requester
1827    let mut update = proto::UpdateContacts::default();
1828    update
1829        .remove_outgoing_requests
1830        .push(responder_id.to_proto());
1831    for connection_id in pool.user_connection_ids(requester_id) {
1832        session.peer.send(connection_id, update.clone())?;
1833    }
1834
1835    // Update incoming contact requests of responder
1836    let mut update = proto::UpdateContacts::default();
1837    update
1838        .remove_incoming_requests
1839        .push(requester_id.to_proto());
1840    for connection_id in pool.user_connection_ids(responder_id) {
1841        session.peer.send(connection_id, update.clone())?;
1842    }
1843
1844    response.send(proto::Ack {})?;
1845    Ok(())
1846}
1847
1848async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1849    let project_id = ProjectId::from_proto(request.project_id);
1850    let project_connection_ids = session
1851        .db()
1852        .await
1853        .project_connection_ids(project_id, session.connection_id)
1854        .await?;
1855    broadcast(
1856        session.connection_id,
1857        project_connection_ids.iter().copied(),
1858        |connection_id| {
1859            session
1860                .peer
1861                .forward_send(session.connection_id, connection_id, request.clone())
1862        },
1863    );
1864    Ok(())
1865}
1866
1867async fn get_private_user_info(
1868    _request: proto::GetPrivateUserInfo,
1869    response: Response<proto::GetPrivateUserInfo>,
1870    session: Session,
1871) -> Result<()> {
1872    let metrics_id = session
1873        .db()
1874        .await
1875        .get_user_metrics_id(session.user_id)
1876        .await?;
1877    let user = session
1878        .db()
1879        .await
1880        .get_user_by_id(session.user_id)
1881        .await?
1882        .ok_or_else(|| anyhow!("user not found"))?;
1883    response.send(proto::GetPrivateUserInfoResponse {
1884        metrics_id,
1885        staff: user.admin,
1886    })?;
1887    Ok(())
1888}
1889
1890fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1891    match message {
1892        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1893        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1894        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1895        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1896        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1897            code: frame.code.into(),
1898            reason: frame.reason,
1899        })),
1900    }
1901}
1902
1903fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1904    match message {
1905        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1906        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1907        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1908        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1909        AxumMessage::Close(frame) => {
1910            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1911                code: frame.code.into(),
1912                reason: frame.reason,
1913            }))
1914        }
1915    }
1916}
1917
1918fn build_initial_contacts_update(
1919    contacts: Vec<db::Contact>,
1920    pool: &ConnectionPool,
1921) -> proto::UpdateContacts {
1922    let mut update = proto::UpdateContacts::default();
1923
1924    for contact in contacts {
1925        match contact {
1926            db::Contact::Accepted {
1927                user_id,
1928                should_notify,
1929                busy,
1930            } => {
1931                update
1932                    .contacts
1933                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1934            }
1935            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1936            db::Contact::Incoming {
1937                user_id,
1938                should_notify,
1939            } => update
1940                .incoming_requests
1941                .push(proto::IncomingContactRequest {
1942                    requester_id: user_id.to_proto(),
1943                    should_notify,
1944                }),
1945        }
1946    }
1947
1948    update
1949}
1950
1951fn contact_for_user(
1952    user_id: UserId,
1953    should_notify: bool,
1954    busy: bool,
1955    pool: &ConnectionPool,
1956) -> proto::Contact {
1957    proto::Contact {
1958        user_id: user_id.to_proto(),
1959        online: pool.is_user_online(user_id),
1960        busy,
1961        should_notify,
1962    }
1963}
1964
1965fn room_updated(room: &proto::Room, peer: &Peer) {
1966    for participant in &room.participants {
1967        if let Some(peer_id) = participant
1968            .peer_id
1969            .ok_or_else(|| anyhow!("invalid participant peer id"))
1970            .trace_err()
1971        {
1972            peer.send(
1973                peer_id.into(),
1974                proto::RoomUpdated {
1975                    room: Some(room.clone()),
1976                },
1977            )
1978            .trace_err();
1979        }
1980    }
1981}
1982
1983async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1984    let db = session.db().await;
1985    let contacts = db.get_contacts(user_id).await?;
1986    let busy = db.is_user_busy(user_id).await?;
1987
1988    let pool = session.connection_pool().await;
1989    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1990    for contact in contacts {
1991        if let db::Contact::Accepted {
1992            user_id: contact_user_id,
1993            ..
1994        } = contact
1995        {
1996            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1997                session
1998                    .peer
1999                    .send(
2000                        contact_conn_id,
2001                        proto::UpdateContacts {
2002                            contacts: vec![updated_contact.clone()],
2003                            remove_contacts: Default::default(),
2004                            incoming_requests: Default::default(),
2005                            remove_incoming_requests: Default::default(),
2006                            outgoing_requests: Default::default(),
2007                            remove_outgoing_requests: Default::default(),
2008                        },
2009                    )
2010                    .trace_err();
2011            }
2012        }
2013    }
2014    Ok(())
2015}
2016
2017async fn leave_room_for_session(session: &Session) -> Result<()> {
2018    let mut contacts_to_update = HashSet::default();
2019
2020    let room_id;
2021    let canceled_calls_to_user_ids;
2022    let live_kit_room;
2023    let delete_live_kit_room;
2024    {
2025        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
2026        contacts_to_update.insert(session.user_id);
2027
2028        for project in left_room.left_projects.values() {
2029            project_left(project, session);
2030        }
2031
2032        room_updated(&left_room.room, &session.peer);
2033        room_id = RoomId::from_proto(left_room.room.id);
2034        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2035        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2036        delete_live_kit_room = left_room.room.participants.is_empty();
2037    }
2038
2039    {
2040        let pool = session.connection_pool().await;
2041        for canceled_user_id in canceled_calls_to_user_ids {
2042            for connection_id in pool.user_connection_ids(canceled_user_id) {
2043                session
2044                    .peer
2045                    .send(
2046                        connection_id,
2047                        proto::CallCanceled {
2048                            room_id: room_id.to_proto(),
2049                        },
2050                    )
2051                    .trace_err();
2052            }
2053            contacts_to_update.insert(canceled_user_id);
2054        }
2055    }
2056
2057    for contact_user_id in contacts_to_update {
2058        update_user_contacts(contact_user_id, &session).await?;
2059    }
2060
2061    if let Some(live_kit) = session.live_kit_client.as_ref() {
2062        live_kit
2063            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
2064            .await
2065            .trace_err();
2066
2067        if delete_live_kit_room {
2068            live_kit.delete_room(live_kit_room).await.trace_err();
2069        }
2070    }
2071
2072    Ok(())
2073}
2074
2075fn project_left(project: &db::LeftProject, session: &Session) {
2076    for connection_id in &project.connection_ids {
2077        if project.host_user_id == session.user_id {
2078            session
2079                .peer
2080                .send(
2081                    *connection_id,
2082                    proto::UnshareProject {
2083                        project_id: project.id.to_proto(),
2084                    },
2085                )
2086                .trace_err();
2087        } else {
2088            session
2089                .peer
2090                .send(
2091                    *connection_id,
2092                    proto::RemoveProjectCollaborator {
2093                        project_id: project.id.to_proto(),
2094                        peer_id: Some(session.connection_id.into()),
2095                    },
2096                )
2097                .trace_err();
2098        }
2099    }
2100
2101    session
2102        .peer
2103        .send(
2104            session.connection_id,
2105            proto::UnshareProject {
2106                project_id: project.id.to_proto(),
2107            },
2108        )
2109        .trace_err();
2110}
2111
/// Extension for best-effort operations whose failure should be recorded but must
/// not abort the caller.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Converts into an `Option`, yielding `Some` on success and `None` on failure.
    ///
    /// Implementations are expected to log the discarded error. The `Result` impl in
    /// this file logs it with `tracing::error!` using its `Debug` form.
    fn trace_err(self) -> Option<Self::Ok>;
}
2117
2118impl<T, E> ResultExt for Result<T, E>
2119where
2120    E: std::fmt::Debug,
2121{
2122    type Ok = T;
2123
2124    fn trace_err(self) -> Option<T> {
2125        match self {
2126            Ok(value) => Some(value),
2127            Err(error) => {
2128                tracing::error!("{:?}", error);
2129                None
2130            }
2131        }
2132    }
2133}