rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    peer: Arc<Peer>,
 142    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 143    app_state: Arc<AppState>,
 144    executor: Executor,
 145    handlers: HashMap<TypeId, MessageHandler>,
 146    teardown: watch::Sender<()>,
 147}
 148
 149pub(crate) struct ConnectionPoolGuard<'a> {
 150    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 151    _not_send: PhantomData<Rc<()>>,
 152}
 153
 154#[derive(Serialize)]
 155pub struct ServerSnapshot<'a> {
 156    peer: &'a Peer,
 157    #[serde(serialize_with = "serialize_deref")]
 158    connection_pool: ConnectionPoolGuard<'a>,
 159}
 160
 161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 162where
 163    S: Serializer,
 164    T: Deref<Target = U>,
 165    U: Serialize,
 166{
 167    Serialize::serialize(value.deref(), serializer)
 168}
 169
 170impl Server {
 171    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 172        let mut server = Self {
 173            peer: Peer::new(),
 174            app_state,
 175            executor,
 176            connection_pool: Default::default(),
 177            handlers: Default::default(),
 178            teardown: watch::channel(()).0,
 179        };
 180
 181        server
 182            .add_request_handler(ping)
 183            .add_request_handler(create_room)
 184            .add_request_handler(join_room)
 185            .add_message_handler(leave_room)
 186            .add_request_handler(call)
 187            .add_request_handler(cancel_call)
 188            .add_message_handler(decline_call)
 189            .add_request_handler(update_participant_location)
 190            .add_request_handler(share_project)
 191            .add_message_handler(unshare_project)
 192            .add_request_handler(join_project)
 193            .add_message_handler(leave_project)
 194            .add_request_handler(update_project)
 195            .add_request_handler(update_worktree)
 196            .add_message_handler(start_language_server)
 197            .add_message_handler(update_language_server)
 198            .add_message_handler(update_diagnostic_summary)
 199            .add_request_handler(forward_project_request::<proto::GetHover>)
 200            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 201            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 202            .add_request_handler(forward_project_request::<proto::GetReferences>)
 203            .add_request_handler(forward_project_request::<proto::SearchProject>)
 204            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 205            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 206            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 209            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 210            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 211            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 213            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 214            .add_request_handler(forward_project_request::<proto::PerformRename>)
 215            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 216            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 217            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 218            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 221            .add_message_handler(create_buffer_for_peer)
 222            .add_request_handler(update_buffer)
 223            .add_message_handler(update_buffer_file)
 224            .add_message_handler(buffer_reloaded)
 225            .add_message_handler(buffer_saved)
 226            .add_request_handler(save_buffer)
 227            .add_request_handler(get_users)
 228            .add_request_handler(fuzzy_search_users)
 229            .add_request_handler(request_contact)
 230            .add_request_handler(remove_contact)
 231            .add_request_handler(respond_to_contact_request)
 232            .add_request_handler(follow)
 233            .add_message_handler(unfollow)
 234            .add_message_handler(update_followers)
 235            .add_message_handler(update_diff_base)
 236            .add_request_handler(get_private_user_info);
 237
 238        Arc::new(server)
 239    }
 240
 241    pub async fn start(&self) -> Result<()> {
 242        let db = self.app_state.db.clone();
 243        let peer = self.peer.clone();
 244        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 245        let pool = self.connection_pool.clone();
 246        let live_kit_client = self.app_state.live_kit_client.clone();
 247
 248        let span = info_span!("start server");
 249        let span_enter = span.enter();
 250
 251        tracing::info!("begin deleting stale projects");
 252        self.app_state.db.delete_stale_projects().await?;
 253        tracing::info!("finish deleting stale projects");
 254
 255        drop(span_enter);
 256        self.executor.spawn_detached(
 257            async move {
 258                tracing::info!("waiting for cleanup timeout");
 259                timeout.await;
 260                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 261                if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
 262                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 263                    for room_id in room_ids {
 264                        let mut contacts_to_update = HashSet::default();
 265                        let mut canceled_calls_to_user_ids = Vec::new();
 266                        let mut live_kit_room = String::new();
 267                        let mut delete_live_kit_room = false;
 268
 269                        if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 270                            tracing::info!(
 271                                room_id = room_id.0,
 272                                new_participant_count = refreshed_room.room.participants.len(),
 273                                "refreshed room"
 274                            );
 275                            room_updated(&refreshed_room.room, &peer);
 276                            contacts_to_update
 277                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 278                            contacts_to_update
 279                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 280                            canceled_calls_to_user_ids =
 281                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 282                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 283                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 284                        }
 285
 286                        {
 287                            let pool = pool.lock();
 288                            for canceled_user_id in canceled_calls_to_user_ids {
 289                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 290                                    peer.send(
 291                                        connection_id,
 292                                        proto::CallCanceled {
 293                                            room_id: room_id.to_proto(),
 294                                        },
 295                                    )
 296                                    .trace_err();
 297                                }
 298                            }
 299                        }
 300
 301                        for user_id in contacts_to_update {
 302                            let busy = db.is_user_busy(user_id).await.trace_err();
 303                            let contacts = db.get_contacts(user_id).await.trace_err();
 304                            if let Some((busy, contacts)) = busy.zip(contacts) {
 305                                let pool = pool.lock();
 306                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 307                                for contact in contacts {
 308                                    if let db::Contact::Accepted {
 309                                        user_id: contact_user_id,
 310                                        ..
 311                                    } = contact
 312                                    {
 313                                        for contact_conn_id in
 314                                            pool.user_connection_ids(contact_user_id)
 315                                        {
 316                                            peer.send(
 317                                                contact_conn_id,
 318                                                proto::UpdateContacts {
 319                                                    contacts: vec![updated_contact.clone()],
 320                                                    remove_contacts: Default::default(),
 321                                                    incoming_requests: Default::default(),
 322                                                    remove_incoming_requests: Default::default(),
 323                                                    outgoing_requests: Default::default(),
 324                                                    remove_outgoing_requests: Default::default(),
 325                                                },
 326                                            )
 327                                            .trace_err();
 328                                        }
 329                                    }
 330                                }
 331                            }
 332                        }
 333
 334                        if let Some(live_kit) = live_kit_client.as_ref() {
 335                            if delete_live_kit_room {
 336                                live_kit.delete_room(live_kit_room).await.trace_err();
 337                            }
 338                        }
 339                    }
 340                }
 341            }
 342            .instrument(span),
 343        );
 344        Ok(())
 345    }
 346
 347    pub fn teardown(&self) {
 348        self.peer.reset();
 349        self.connection_pool.lock().reset();
 350        let _ = self.teardown.send(());
 351    }
 352
 353    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 354    where
 355        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 356        Fut: 'static + Send + Future<Output = Result<()>>,
 357        M: EnvelopedMessage,
 358    {
 359        let prev_handler = self.handlers.insert(
 360            TypeId::of::<M>(),
 361            Box::new(move |envelope, session| {
 362                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 363                let span = info_span!(
 364                    "handle message",
 365                    payload_type = envelope.payload_type_name()
 366                );
 367                span.in_scope(|| {
 368                    tracing::info!(
 369                        payload_type = envelope.payload_type_name(),
 370                        "message received"
 371                    );
 372                });
 373                let future = (handler)(*envelope, session);
 374                async move {
 375                    if let Err(error) = future.await {
 376                        tracing::error!(%error, "error handling message");
 377                    }
 378                }
 379                .instrument(span)
 380                .boxed()
 381            }),
 382        );
 383        if prev_handler.is_some() {
 384            panic!("registered a handler for the same message twice");
 385        }
 386        self
 387    }
 388
 389    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 390    where
 391        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 392        Fut: 'static + Send + Future<Output = Result<()>>,
 393        M: EnvelopedMessage,
 394    {
 395        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 396        self
 397    }
 398
 399    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 400    where
 401        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 402        Fut: Send + Future<Output = Result<()>>,
 403        M: RequestMessage,
 404    {
 405        let handler = Arc::new(handler);
 406        self.add_handler(move |envelope, session| {
 407            let receipt = envelope.receipt();
 408            let handler = handler.clone();
 409            async move {
 410                let peer = session.peer.clone();
 411                let responded = Arc::new(AtomicBool::default());
 412                let response = Response {
 413                    peer: peer.clone(),
 414                    responded: responded.clone(),
 415                    receipt,
 416                };
 417                match (handler)(envelope.payload, response, session).await {
 418                    Ok(()) => {
 419                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 420                            Ok(())
 421                        } else {
 422                            Err(anyhow!("handler did not send a response"))?
 423                        }
 424                    }
 425                    Err(error) => {
 426                        peer.respond_with_error(
 427                            receipt,
 428                            proto::Error {
 429                                message: error.to_string(),
 430                            },
 431                        )?;
 432                        Err(error)
 433                    }
 434                }
 435            }
 436        })
 437    }
 438
 439    pub fn handle_connection(
 440        self: &Arc<Self>,
 441        connection: Connection,
 442        address: String,
 443        user: User,
 444        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 445        executor: Executor,
 446    ) -> impl Future<Output = Result<()>> {
 447        let this = self.clone();
 448        let user_id = user.id;
 449        let login = user.github_login;
 450        let span = info_span!("handle connection", %user_id, %login, %address);
 451        let mut teardown = self.teardown.subscribe();
 452        async move {
 453            let (connection_id, handle_io, mut incoming_rx) = this
 454                .peer
 455                .add_connection(connection, {
 456                    let executor = executor.clone();
 457                    move |duration| executor.sleep(duration)
 458                });
 459
 460            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 461            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 462            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 463
 464            if let Some(send_connection_id) = send_connection_id.take() {
 465                let _ = send_connection_id.send(connection_id);
 466            }
 467
 468            if !user.connected_once {
 469                this.peer.send(connection_id, proto::ShowContacts {})?;
 470                this.app_state.db.set_user_connected_once(user_id, true).await?;
 471            }
 472
 473            let (contacts, invite_code) = future::try_join(
 474                this.app_state.db.get_contacts(user_id),
 475                this.app_state.db.get_invite_code_for_user(user_id)
 476            ).await?;
 477
 478            {
 479                let mut pool = this.connection_pool.lock();
 480                pool.add_connection(connection_id, user_id, user.admin);
 481                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 482
 483                if let Some((code, count)) = invite_code {
 484                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 485                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 486                        count: count as u32,
 487                    })?;
 488                }
 489            }
 490
 491            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 492                this.peer.send(connection_id, incoming_call)?;
 493            }
 494
 495            let session = Session {
 496                user_id,
 497                connection_id,
 498                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 499                peer: this.peer.clone(),
 500                connection_pool: this.connection_pool.clone(),
 501                live_kit_client: this.app_state.live_kit_client.clone()
 502            };
 503            update_user_contacts(user_id, &session).await?;
 504
 505            let handle_io = handle_io.fuse();
 506            futures::pin_mut!(handle_io);
 507
 508            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 509            // This prevents deadlocks when e.g., client A performs a request to client B and
 510            // client B performs a request to client A. If both clients stop processing further
 511            // messages until their respective request completes, they won't have a chance to
 512            // respond to the other client's request and cause a deadlock.
 513            //
 514            // This arrangement ensures we will attempt to process earlier messages first, but fall
 515            // back to processing messages arrived later in the spirit of making progress.
 516            let mut foreground_message_handlers = FuturesUnordered::new();
 517            loop {
 518                let next_message = incoming_rx.next().fuse();
 519                futures::pin_mut!(next_message);
 520                futures::select_biased! {
 521                    _ = teardown.changed().fuse() => return Ok(()),
 522                    result = handle_io => {
 523                        if let Err(error) = result {
 524                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 525                        }
 526                        break;
 527                    }
 528                    _ = foreground_message_handlers.next() => {}
 529                    message = next_message => {
 530                        if let Some(message) = message {
 531                            let type_name = message.payload_type_name();
 532                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 533                            let span_enter = span.enter();
 534                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 535                                let is_background = message.is_background();
 536                                let handle_message = (handler)(message, session.clone());
 537                                drop(span_enter);
 538
 539                                let handle_message = handle_message.instrument(span);
 540                                if is_background {
 541                                    executor.spawn_detached(handle_message);
 542                                } else {
 543                                    foreground_message_handlers.push(handle_message);
 544                                }
 545                            } else {
 546                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 547                            }
 548                        } else {
 549                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 550                            break;
 551                        }
 552                    }
 553                }
 554            }
 555
 556            drop(foreground_message_handlers);
 557            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 558            if let Err(error) = sign_out(session, teardown, executor).await {
 559                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 560            }
 561
 562            Ok(())
 563        }.instrument(span)
 564    }
 565
 566    pub async fn invite_code_redeemed(
 567        self: &Arc<Self>,
 568        inviter_id: UserId,
 569        invitee_id: UserId,
 570    ) -> Result<()> {
 571        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 572            if let Some(code) = &user.invite_code {
 573                let pool = self.connection_pool.lock();
 574                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 575                for connection_id in pool.user_connection_ids(inviter_id) {
 576                    self.peer.send(
 577                        connection_id,
 578                        proto::UpdateContacts {
 579                            contacts: vec![invitee_contact.clone()],
 580                            ..Default::default()
 581                        },
 582                    )?;
 583                    self.peer.send(
 584                        connection_id,
 585                        proto::UpdateInviteInfo {
 586                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 587                            count: user.invite_count as u32,
 588                        },
 589                    )?;
 590                }
 591            }
 592        }
 593        Ok(())
 594    }
 595
 596    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 597        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 598            if let Some(invite_code) = &user.invite_code {
 599                let pool = self.connection_pool.lock();
 600                for connection_id in pool.user_connection_ids(user_id) {
 601                    self.peer.send(
 602                        connection_id,
 603                        proto::UpdateInviteInfo {
 604                            url: format!(
 605                                "{}{}",
 606                                self.app_state.config.invite_link_prefix, invite_code
 607                            ),
 608                            count: user.invite_count as u32,
 609                        },
 610                    )?;
 611                }
 612            }
 613        }
 614        Ok(())
 615    }
 616
 617    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 618        ServerSnapshot {
 619            connection_pool: ConnectionPoolGuard {
 620                guard: self.connection_pool.lock(),
 621                _not_send: PhantomData,
 622            },
 623            peer: &self.peer,
 624        }
 625    }
 626}
 627
 628impl<'a> Deref for ConnectionPoolGuard<'a> {
 629    type Target = ConnectionPool;
 630
 631    fn deref(&self) -> &Self::Target {
 632        &*self.guard
 633    }
 634}
 635
 636impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 637    fn deref_mut(&mut self) -> &mut Self::Target {
 638        &mut *self.guard
 639    }
 640}
 641
 642impl<'a> Drop for ConnectionPoolGuard<'a> {
 643    fn drop(&mut self) {
 644        #[cfg(test)]
 645        self.check_invariants();
 646    }
 647}
 648
 649fn broadcast<F>(
 650    sender_id: ConnectionId,
 651    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 652    mut f: F,
 653) where
 654    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 655{
 656    for receiver_id in receiver_ids {
 657        if receiver_id != sender_id {
 658            f(receiver_id).trace_err();
 659        }
 660    }
 661}
 662
 663lazy_static! {
 664    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 665}
 666
 667pub struct ProtocolVersion(u32);
 668
 669impl Header for ProtocolVersion {
 670    fn name() -> &'static HeaderName {
 671        &ZED_PROTOCOL_VERSION
 672    }
 673
 674    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 675    where
 676        Self: Sized,
 677        I: Iterator<Item = &'i axum::http::HeaderValue>,
 678    {
 679        let version = values
 680            .next()
 681            .ok_or_else(axum::headers::Error::invalid)?
 682            .to_str()
 683            .map_err(|_| axum::headers::Error::invalid())?
 684            .parse()
 685            .map_err(|_| axum::headers::Error::invalid())?;
 686        Ok(Self(version))
 687    }
 688
 689    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 690        values.extend([self.0.to_string().parse().unwrap()]);
 691    }
 692}
 693
 694pub fn routes(server: Arc<Server>) -> Router<Body> {
 695    Router::new()
 696        .route("/rpc", get(handle_websocket_request))
 697        .layer(
 698            ServiceBuilder::new()
 699                .layer(Extension(server.app_state.clone()))
 700                .layer(middleware::from_fn(auth::validate_header)),
 701        )
 702        .route("/metrics", get(handle_metrics))
 703        .layer(Extension(server))
 704}
 705
 706pub async fn handle_websocket_request(
 707    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 708    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 709    Extension(server): Extension<Arc<Server>>,
 710    Extension(user): Extension<User>,
 711    ws: WebSocketUpgrade,
 712) -> axum::response::Response {
 713    if protocol_version != rpc::PROTOCOL_VERSION {
 714        return (
 715            StatusCode::UPGRADE_REQUIRED,
 716            "client must be upgraded".to_string(),
 717        )
 718            .into_response();
 719    }
 720    let socket_address = socket_address.to_string();
 721    ws.on_upgrade(move |socket| {
 722        use util::ResultExt;
 723        let socket = socket
 724            .map_ok(to_tungstenite_message)
 725            .err_into()
 726            .with(|message| async move { Ok(to_axum_message(message)) });
 727        let connection = Connection::new(Box::pin(socket));
 728        async move {
 729            server
 730                .handle_connection(connection, socket_address, user, None, Executor::Production)
 731                .await
 732                .log_err();
 733        }
 734    })
 735}
 736
 737pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 738    let connections = server
 739        .connection_pool
 740        .lock()
 741        .connections()
 742        .filter(|connection| !connection.admin)
 743        .count();
 744
 745    METRIC_CONNECTIONS.set(connections as _);
 746
 747    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 748    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 749
 750    let encoder = prometheus::TextEncoder::new();
 751    let metric_families = prometheus::gather();
 752    let encoded_metrics = encoder
 753        .encode_to_string(&metric_families)
 754        .map_err(|err| anyhow!("{}", err))?;
 755    Ok(encoded_metrics)
 756}
 757
 758#[instrument(err, skip(executor))]
 759async fn sign_out(
 760    session: Session,
 761    mut teardown: watch::Receiver<()>,
 762    executor: Executor,
 763) -> Result<()> {
 764    session.peer.disconnect(session.connection_id);
 765    session
 766        .connection_pool()
 767        .await
 768        .remove_connection(session.connection_id)?;
 769
 770    if let Some(mut left_projects) = session
 771        .db()
 772        .await
 773        .connection_lost(session.connection_id)
 774        .await
 775        .trace_err()
 776    {
 777        for left_project in mem::take(&mut *left_projects) {
 778            project_left(&left_project, &session);
 779        }
 780    }
 781
 782    futures::select_biased! {
 783        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 784            leave_room_for_session(&session).await.trace_err();
 785
 786            if !session
 787                .connection_pool()
 788                .await
 789                .is_user_online(session.user_id)
 790            {
 791                let db = session.db().await;
 792                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
 793                    room_updated(&room, &session.peer);
 794                }
 795            }
 796            update_user_contacts(session.user_id, &session).await?;
 797        }
 798        _ = teardown.changed().fuse() => {}
 799    }
 800
 801    Ok(())
 802}
 803
 804async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 805    response.send(proto::Ack {})?;
 806    Ok(())
 807}
 808
 809async fn create_room(
 810    _request: proto::CreateRoom,
 811    response: Response<proto::CreateRoom>,
 812    session: Session,
 813) -> Result<()> {
 814    let live_kit_room = nanoid::nanoid!(30);
 815    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 816        if let Some(_) = live_kit
 817            .create_room(live_kit_room.clone())
 818            .await
 819            .trace_err()
 820        {
 821            if let Some(token) = live_kit
 822                .room_token(&live_kit_room, &session.connection_id.to_string())
 823                .trace_err()
 824            {
 825                Some(proto::LiveKitConnectionInfo {
 826                    server_url: live_kit.url().into(),
 827                    token,
 828                })
 829            } else {
 830                None
 831            }
 832        } else {
 833            None
 834        }
 835    } else {
 836        None
 837    };
 838
 839    {
 840        let room = session
 841            .db()
 842            .await
 843            .create_room(session.user_id, session.connection_id, &live_kit_room)
 844            .await?;
 845
 846        response.send(proto::CreateRoomResponse {
 847            room: Some(room.clone()),
 848            live_kit_connection_info,
 849        })?;
 850    }
 851
 852    update_user_contacts(session.user_id, &session).await?;
 853    Ok(())
 854}
 855
 856async fn join_room(
 857    request: proto::JoinRoom,
 858    response: Response<proto::JoinRoom>,
 859    session: Session,
 860) -> Result<()> {
 861    let room_id = RoomId::from_proto(request.id);
 862    let room = {
 863        let room = session
 864            .db()
 865            .await
 866            .join_room(room_id, session.user_id, session.connection_id)
 867            .await?;
 868        room_updated(&room, &session.peer);
 869        room.clone()
 870    };
 871
 872    for connection_id in session
 873        .connection_pool()
 874        .await
 875        .user_connection_ids(session.user_id)
 876    {
 877        session
 878            .peer
 879            .send(
 880                connection_id,
 881                proto::CallCanceled {
 882                    room_id: room_id.to_proto(),
 883                },
 884            )
 885            .trace_err();
 886    }
 887
 888    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 889        if let Some(token) = live_kit
 890            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 891            .trace_err()
 892        {
 893            Some(proto::LiveKitConnectionInfo {
 894                server_url: live_kit.url().into(),
 895                token,
 896            })
 897        } else {
 898            None
 899        }
 900    } else {
 901        None
 902    };
 903
 904    response.send(proto::JoinRoomResponse {
 905        room: Some(room),
 906        live_kit_connection_info,
 907    })?;
 908
 909    update_user_contacts(session.user_id, &session).await?;
 910    Ok(())
 911}
 912
 913async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 914    leave_room_for_session(&session).await
 915}
 916
 917async fn call(
 918    request: proto::Call,
 919    response: Response<proto::Call>,
 920    session: Session,
 921) -> Result<()> {
 922    let room_id = RoomId::from_proto(request.room_id);
 923    let calling_user_id = session.user_id;
 924    let calling_connection_id = session.connection_id;
 925    let called_user_id = UserId::from_proto(request.called_user_id);
 926    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 927    if !session
 928        .db()
 929        .await
 930        .has_contact(calling_user_id, called_user_id)
 931        .await?
 932    {
 933        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 934    }
 935
 936    let incoming_call = {
 937        let (room, incoming_call) = &mut *session
 938            .db()
 939            .await
 940            .call(
 941                room_id,
 942                calling_user_id,
 943                calling_connection_id,
 944                called_user_id,
 945                initial_project_id,
 946            )
 947            .await?;
 948        room_updated(&room, &session.peer);
 949        mem::take(incoming_call)
 950    };
 951    update_user_contacts(called_user_id, &session).await?;
 952
 953    let mut calls = session
 954        .connection_pool()
 955        .await
 956        .user_connection_ids(called_user_id)
 957        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 958        .collect::<FuturesUnordered<_>>();
 959
 960    while let Some(call_response) = calls.next().await {
 961        match call_response.as_ref() {
 962            Ok(_) => {
 963                response.send(proto::Ack {})?;
 964                return Ok(());
 965            }
 966            Err(_) => {
 967                call_response.trace_err();
 968            }
 969        }
 970    }
 971
 972    {
 973        let room = session
 974            .db()
 975            .await
 976            .call_failed(room_id, called_user_id)
 977            .await?;
 978        room_updated(&room, &session.peer);
 979    }
 980    update_user_contacts(called_user_id, &session).await?;
 981
 982    Err(anyhow!("failed to ring user"))?
 983}
 984
 985async fn cancel_call(
 986    request: proto::CancelCall,
 987    response: Response<proto::CancelCall>,
 988    session: Session,
 989) -> Result<()> {
 990    let called_user_id = UserId::from_proto(request.called_user_id);
 991    let room_id = RoomId::from_proto(request.room_id);
 992    {
 993        let room = session
 994            .db()
 995            .await
 996            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 997            .await?;
 998        room_updated(&room, &session.peer);
 999    }
1000
1001    for connection_id in session
1002        .connection_pool()
1003        .await
1004        .user_connection_ids(called_user_id)
1005    {
1006        session
1007            .peer
1008            .send(
1009                connection_id,
1010                proto::CallCanceled {
1011                    room_id: room_id.to_proto(),
1012                },
1013            )
1014            .trace_err();
1015    }
1016    response.send(proto::Ack {})?;
1017
1018    update_user_contacts(called_user_id, &session).await?;
1019    Ok(())
1020}
1021
1022async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1023    let room_id = RoomId::from_proto(message.room_id);
1024    {
1025        let room = session
1026            .db()
1027            .await
1028            .decline_call(Some(room_id), session.user_id)
1029            .await?;
1030        room_updated(&room, &session.peer);
1031    }
1032
1033    for connection_id in session
1034        .connection_pool()
1035        .await
1036        .user_connection_ids(session.user_id)
1037    {
1038        session
1039            .peer
1040            .send(
1041                connection_id,
1042                proto::CallCanceled {
1043                    room_id: room_id.to_proto(),
1044                },
1045            )
1046            .trace_err();
1047    }
1048    update_user_contacts(session.user_id, &session).await?;
1049    Ok(())
1050}
1051
1052async fn update_participant_location(
1053    request: proto::UpdateParticipantLocation,
1054    response: Response<proto::UpdateParticipantLocation>,
1055    session: Session,
1056) -> Result<()> {
1057    let room_id = RoomId::from_proto(request.room_id);
1058    let location = request
1059        .location
1060        .ok_or_else(|| anyhow!("invalid location"))?;
1061    let room = session
1062        .db()
1063        .await
1064        .update_room_participant_location(room_id, session.connection_id, location)
1065        .await?;
1066    room_updated(&room, &session.peer);
1067    response.send(proto::Ack {})?;
1068    Ok(())
1069}
1070
1071async fn share_project(
1072    request: proto::ShareProject,
1073    response: Response<proto::ShareProject>,
1074    session: Session,
1075) -> Result<()> {
1076    let (project_id, room) = &*session
1077        .db()
1078        .await
1079        .share_project(
1080            RoomId::from_proto(request.room_id),
1081            session.connection_id,
1082            &request.worktrees,
1083        )
1084        .await?;
1085    response.send(proto::ShareProjectResponse {
1086        project_id: project_id.to_proto(),
1087    })?;
1088    room_updated(&room, &session.peer);
1089
1090    Ok(())
1091}
1092
1093async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1094    let project_id = ProjectId::from_proto(message.project_id);
1095
1096    let (room, guest_connection_ids) = &*session
1097        .db()
1098        .await
1099        .unshare_project(project_id, session.connection_id)
1100        .await?;
1101
1102    broadcast(
1103        session.connection_id,
1104        guest_connection_ids.iter().copied(),
1105        |conn_id| session.peer.send(conn_id, message.clone()),
1106    );
1107    room_updated(&room, &session.peer);
1108
1109    Ok(())
1110}
1111
1112async fn join_project(
1113    request: proto::JoinProject,
1114    response: Response<proto::JoinProject>,
1115    session: Session,
1116) -> Result<()> {
1117    let project_id = ProjectId::from_proto(request.project_id);
1118    let guest_user_id = session.user_id;
1119
1120    tracing::info!(%project_id, "join project");
1121
1122    let (project, replica_id) = &mut *session
1123        .db()
1124        .await
1125        .join_project(project_id, session.connection_id)
1126        .await?;
1127
1128    let collaborators = project
1129        .collaborators
1130        .iter()
1131        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1132        .map(|collaborator| proto::Collaborator {
1133            peer_id: collaborator.connection_id as u32,
1134            replica_id: collaborator.replica_id.0 as u32,
1135            user_id: collaborator.user_id.to_proto(),
1136        })
1137        .collect::<Vec<_>>();
1138    let worktrees = project
1139        .worktrees
1140        .iter()
1141        .map(|(id, worktree)| proto::WorktreeMetadata {
1142            id: *id,
1143            root_name: worktree.root_name.clone(),
1144            visible: worktree.visible,
1145            abs_path: worktree.abs_path.clone(),
1146        })
1147        .collect::<Vec<_>>();
1148
1149    for collaborator in &collaborators {
1150        session
1151            .peer
1152            .send(
1153                ConnectionId(collaborator.peer_id),
1154                proto::AddProjectCollaborator {
1155                    project_id: project_id.to_proto(),
1156                    collaborator: Some(proto::Collaborator {
1157                        peer_id: session.connection_id.0,
1158                        replica_id: replica_id.0 as u32,
1159                        user_id: guest_user_id.to_proto(),
1160                    }),
1161                },
1162            )
1163            .trace_err();
1164    }
1165
1166    // First, we send the metadata associated with each worktree.
1167    response.send(proto::JoinProjectResponse {
1168        worktrees: worktrees.clone(),
1169        replica_id: replica_id.0 as u32,
1170        collaborators: collaborators.clone(),
1171        language_servers: project.language_servers.clone(),
1172    })?;
1173
1174    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1175        #[cfg(any(test, feature = "test-support"))]
1176        const MAX_CHUNK_SIZE: usize = 2;
1177        #[cfg(not(any(test, feature = "test-support")))]
1178        const MAX_CHUNK_SIZE: usize = 256;
1179
1180        // Stream this worktree's entries.
1181        let message = proto::UpdateWorktree {
1182            project_id: project_id.to_proto(),
1183            worktree_id,
1184            abs_path: worktree.abs_path.clone(),
1185            root_name: worktree.root_name,
1186            updated_entries: worktree.entries,
1187            removed_entries: Default::default(),
1188            scan_id: worktree.scan_id,
1189            is_last_update: worktree.is_complete,
1190        };
1191        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1192            session.peer.send(session.connection_id, update.clone())?;
1193        }
1194
1195        // Stream this worktree's diagnostics.
1196        for summary in worktree.diagnostic_summaries {
1197            session.peer.send(
1198                session.connection_id,
1199                proto::UpdateDiagnosticSummary {
1200                    project_id: project_id.to_proto(),
1201                    worktree_id: worktree.id,
1202                    summary: Some(summary),
1203                },
1204            )?;
1205        }
1206    }
1207
1208    for language_server in &project.language_servers {
1209        session.peer.send(
1210            session.connection_id,
1211            proto::UpdateLanguageServer {
1212                project_id: project_id.to_proto(),
1213                language_server_id: language_server.id,
1214                variant: Some(
1215                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1216                        proto::LspDiskBasedDiagnosticsUpdated {},
1217                    ),
1218                ),
1219            },
1220        )?;
1221    }
1222
1223    Ok(())
1224}
1225
1226async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1227    let sender_id = session.connection_id;
1228    let project_id = ProjectId::from_proto(request.project_id);
1229
1230    let project = session
1231        .db()
1232        .await
1233        .leave_project(project_id, sender_id)
1234        .await?;
1235    tracing::info!(
1236        %project_id,
1237        host_user_id = %project.host_user_id,
1238        host_connection_id = %project.host_connection_id,
1239        "leave project"
1240    );
1241    project_left(&project, &session);
1242
1243    Ok(())
1244}
1245
1246async fn update_project(
1247    request: proto::UpdateProject,
1248    response: Response<proto::UpdateProject>,
1249    session: Session,
1250) -> Result<()> {
1251    let project_id = ProjectId::from_proto(request.project_id);
1252    let (room, guest_connection_ids) = &*session
1253        .db()
1254        .await
1255        .update_project(project_id, session.connection_id, &request.worktrees)
1256        .await?;
1257    broadcast(
1258        session.connection_id,
1259        guest_connection_ids.iter().copied(),
1260        |connection_id| {
1261            session
1262                .peer
1263                .forward_send(session.connection_id, connection_id, request.clone())
1264        },
1265    );
1266    room_updated(&room, &session.peer);
1267    response.send(proto::Ack {})?;
1268
1269    Ok(())
1270}
1271
1272async fn update_worktree(
1273    request: proto::UpdateWorktree,
1274    response: Response<proto::UpdateWorktree>,
1275    session: Session,
1276) -> Result<()> {
1277    let guest_connection_ids = session
1278        .db()
1279        .await
1280        .update_worktree(&request, session.connection_id)
1281        .await?;
1282
1283    broadcast(
1284        session.connection_id,
1285        guest_connection_ids.iter().copied(),
1286        |connection_id| {
1287            session
1288                .peer
1289                .forward_send(session.connection_id, connection_id, request.clone())
1290        },
1291    );
1292    response.send(proto::Ack {})?;
1293    Ok(())
1294}
1295
1296async fn update_diagnostic_summary(
1297    message: proto::UpdateDiagnosticSummary,
1298    session: Session,
1299) -> Result<()> {
1300    let guest_connection_ids = session
1301        .db()
1302        .await
1303        .update_diagnostic_summary(&message, session.connection_id)
1304        .await?;
1305
1306    broadcast(
1307        session.connection_id,
1308        guest_connection_ids.iter().copied(),
1309        |connection_id| {
1310            session
1311                .peer
1312                .forward_send(session.connection_id, connection_id, message.clone())
1313        },
1314    );
1315
1316    Ok(())
1317}
1318
1319async fn start_language_server(
1320    request: proto::StartLanguageServer,
1321    session: Session,
1322) -> Result<()> {
1323    let guest_connection_ids = session
1324        .db()
1325        .await
1326        .start_language_server(&request, session.connection_id)
1327        .await?;
1328
1329    broadcast(
1330        session.connection_id,
1331        guest_connection_ids.iter().copied(),
1332        |connection_id| {
1333            session
1334                .peer
1335                .forward_send(session.connection_id, connection_id, request.clone())
1336        },
1337    );
1338    Ok(())
1339}
1340
1341async fn update_language_server(
1342    request: proto::UpdateLanguageServer,
1343    session: Session,
1344) -> Result<()> {
1345    let project_id = ProjectId::from_proto(request.project_id);
1346    let project_connection_ids = session
1347        .db()
1348        .await
1349        .project_connection_ids(project_id, session.connection_id)
1350        .await?;
1351    broadcast(
1352        session.connection_id,
1353        project_connection_ids.iter().copied(),
1354        |connection_id| {
1355            session
1356                .peer
1357                .forward_send(session.connection_id, connection_id, request.clone())
1358        },
1359    );
1360    Ok(())
1361}
1362
1363async fn forward_project_request<T>(
1364    request: T,
1365    response: Response<T>,
1366    session: Session,
1367) -> Result<()>
1368where
1369    T: EntityMessage + RequestMessage,
1370{
1371    let project_id = ProjectId::from_proto(request.remote_entity_id());
1372    let host_connection_id = {
1373        let collaborators = session
1374            .db()
1375            .await
1376            .project_collaborators(project_id, session.connection_id)
1377            .await?;
1378        ConnectionId(
1379            collaborators
1380                .iter()
1381                .find(|collaborator| collaborator.is_host)
1382                .ok_or_else(|| anyhow!("host not found"))?
1383                .connection_id as u32,
1384        )
1385    };
1386
1387    let payload = session
1388        .peer
1389        .forward_request(session.connection_id, host_connection_id, request)
1390        .await?;
1391
1392    response.send(payload)?;
1393    Ok(())
1394}
1395
1396async fn save_buffer(
1397    request: proto::SaveBuffer,
1398    response: Response<proto::SaveBuffer>,
1399    session: Session,
1400) -> Result<()> {
1401    let project_id = ProjectId::from_proto(request.project_id);
1402    let host_connection_id = {
1403        let collaborators = session
1404            .db()
1405            .await
1406            .project_collaborators(project_id, session.connection_id)
1407            .await?;
1408        let host = collaborators
1409            .iter()
1410            .find(|collaborator| collaborator.is_host)
1411            .ok_or_else(|| anyhow!("host not found"))?;
1412        ConnectionId(host.connection_id as u32)
1413    };
1414    let response_payload = session
1415        .peer
1416        .forward_request(session.connection_id, host_connection_id, request.clone())
1417        .await?;
1418
1419    let mut collaborators = session
1420        .db()
1421        .await
1422        .project_collaborators(project_id, session.connection_id)
1423        .await?;
1424    collaborators
1425        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1426    let project_connection_ids = collaborators
1427        .iter()
1428        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1429    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1430        session
1431            .peer
1432            .forward_send(host_connection_id, conn_id, response_payload.clone())
1433    });
1434    response.send(response_payload)?;
1435    Ok(())
1436}
1437
1438async fn create_buffer_for_peer(
1439    request: proto::CreateBufferForPeer,
1440    session: Session,
1441) -> Result<()> {
1442    session.peer.forward_send(
1443        session.connection_id,
1444        ConnectionId(request.peer_id),
1445        request,
1446    )?;
1447    Ok(())
1448}
1449
1450async fn update_buffer(
1451    request: proto::UpdateBuffer,
1452    response: Response<proto::UpdateBuffer>,
1453    session: Session,
1454) -> Result<()> {
1455    let project_id = ProjectId::from_proto(request.project_id);
1456    let project_connection_ids = session
1457        .db()
1458        .await
1459        .project_connection_ids(project_id, session.connection_id)
1460        .await?;
1461
1462    broadcast(
1463        session.connection_id,
1464        project_connection_ids.iter().copied(),
1465        |connection_id| {
1466            session
1467                .peer
1468                .forward_send(session.connection_id, connection_id, request.clone())
1469        },
1470    );
1471    response.send(proto::Ack {})?;
1472    Ok(())
1473}
1474
1475async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1476    let project_id = ProjectId::from_proto(request.project_id);
1477    let project_connection_ids = session
1478        .db()
1479        .await
1480        .project_connection_ids(project_id, session.connection_id)
1481        .await?;
1482
1483    broadcast(
1484        session.connection_id,
1485        project_connection_ids.iter().copied(),
1486        |connection_id| {
1487            session
1488                .peer
1489                .forward_send(session.connection_id, connection_id, request.clone())
1490        },
1491    );
1492    Ok(())
1493}
1494
1495async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1496    let project_id = ProjectId::from_proto(request.project_id);
1497    let project_connection_ids = session
1498        .db()
1499        .await
1500        .project_connection_ids(project_id, session.connection_id)
1501        .await?;
1502    broadcast(
1503        session.connection_id,
1504        project_connection_ids.iter().copied(),
1505        |connection_id| {
1506            session
1507                .peer
1508                .forward_send(session.connection_id, connection_id, request.clone())
1509        },
1510    );
1511    Ok(())
1512}
1513
1514async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1515    let project_id = ProjectId::from_proto(request.project_id);
1516    let project_connection_ids = session
1517        .db()
1518        .await
1519        .project_connection_ids(project_id, session.connection_id)
1520        .await?;
1521    broadcast(
1522        session.connection_id,
1523        project_connection_ids.iter().copied(),
1524        |connection_id| {
1525            session
1526                .peer
1527                .forward_send(session.connection_id, connection_id, request.clone())
1528        },
1529    );
1530    Ok(())
1531}
1532
1533async fn follow(
1534    request: proto::Follow,
1535    response: Response<proto::Follow>,
1536    session: Session,
1537) -> Result<()> {
1538    let project_id = ProjectId::from_proto(request.project_id);
1539    let leader_id = ConnectionId(request.leader_id);
1540    let follower_id = session.connection_id;
1541    {
1542        let project_connection_ids = session
1543            .db()
1544            .await
1545            .project_connection_ids(project_id, session.connection_id)
1546            .await?;
1547
1548        if !project_connection_ids.contains(&leader_id) {
1549            Err(anyhow!("no such peer"))?;
1550        }
1551    }
1552
1553    let mut response_payload = session
1554        .peer
1555        .forward_request(session.connection_id, leader_id, request)
1556        .await?;
1557    response_payload
1558        .views
1559        .retain(|view| view.leader_id != Some(follower_id.0));
1560    response.send(response_payload)?;
1561    Ok(())
1562}
1563
1564async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1565    let project_id = ProjectId::from_proto(request.project_id);
1566    let leader_id = ConnectionId(request.leader_id);
1567    let project_connection_ids = session
1568        .db()
1569        .await
1570        .project_connection_ids(project_id, session.connection_id)
1571        .await?;
1572    if !project_connection_ids.contains(&leader_id) {
1573        Err(anyhow!("no such peer"))?;
1574    }
1575    session
1576        .peer
1577        .forward_send(session.connection_id, leader_id, request)?;
1578    Ok(())
1579}
1580
1581async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1582    let project_id = ProjectId::from_proto(request.project_id);
1583    let project_connection_ids = session
1584        .db
1585        .lock()
1586        .await
1587        .project_connection_ids(project_id, session.connection_id)
1588        .await?;
1589
1590    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1591        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1592        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1593        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1594    });
1595    for follower_id in &request.follower_ids {
1596        let follower_id = ConnectionId(*follower_id);
1597        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1598            session
1599                .peer
1600                .forward_send(session.connection_id, follower_id, request.clone())?;
1601        }
1602    }
1603    Ok(())
1604}
1605
1606async fn get_users(
1607    request: proto::GetUsers,
1608    response: Response<proto::GetUsers>,
1609    session: Session,
1610) -> Result<()> {
1611    let user_ids = request
1612        .user_ids
1613        .into_iter()
1614        .map(UserId::from_proto)
1615        .collect();
1616    let users = session
1617        .db()
1618        .await
1619        .get_users_by_ids(user_ids)
1620        .await?
1621        .into_iter()
1622        .map(|user| proto::User {
1623            id: user.id.to_proto(),
1624            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1625            github_login: user.github_login,
1626        })
1627        .collect();
1628    response.send(proto::UsersResponse { users })?;
1629    Ok(())
1630}
1631
1632async fn fuzzy_search_users(
1633    request: proto::FuzzySearchUsers,
1634    response: Response<proto::FuzzySearchUsers>,
1635    session: Session,
1636) -> Result<()> {
1637    let query = request.query;
1638    let users = match query.len() {
1639        0 => vec![],
1640        1 | 2 => session
1641            .db()
1642            .await
1643            .get_user_by_github_account(&query, None)
1644            .await?
1645            .into_iter()
1646            .collect(),
1647        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1648    };
1649    let users = users
1650        .into_iter()
1651        .filter(|user| user.id != session.user_id)
1652        .map(|user| proto::User {
1653            id: user.id.to_proto(),
1654            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1655            github_login: user.github_login,
1656        })
1657        .collect();
1658    response.send(proto::UsersResponse { users })?;
1659    Ok(())
1660}
1661
1662async fn request_contact(
1663    request: proto::RequestContact,
1664    response: Response<proto::RequestContact>,
1665    session: Session,
1666) -> Result<()> {
1667    let requester_id = session.user_id;
1668    let responder_id = UserId::from_proto(request.responder_id);
1669    if requester_id == responder_id {
1670        return Err(anyhow!("cannot add yourself as a contact"))?;
1671    }
1672
1673    session
1674        .db()
1675        .await
1676        .send_contact_request(requester_id, responder_id)
1677        .await?;
1678
1679    // Update outgoing contact requests of requester
1680    let mut update = proto::UpdateContacts::default();
1681    update.outgoing_requests.push(responder_id.to_proto());
1682    for connection_id in session
1683        .connection_pool()
1684        .await
1685        .user_connection_ids(requester_id)
1686    {
1687        session.peer.send(connection_id, update.clone())?;
1688    }
1689
1690    // Update incoming contact requests of responder
1691    let mut update = proto::UpdateContacts::default();
1692    update
1693        .incoming_requests
1694        .push(proto::IncomingContactRequest {
1695            requester_id: requester_id.to_proto(),
1696            should_notify: true,
1697        });
1698    for connection_id in session
1699        .connection_pool()
1700        .await
1701        .user_connection_ids(responder_id)
1702    {
1703        session.peer.send(connection_id, update.clone())?;
1704    }
1705
1706    response.send(proto::Ack {})?;
1707    Ok(())
1708}
1709
1710async fn respond_to_contact_request(
1711    request: proto::RespondToContactRequest,
1712    response: Response<proto::RespondToContactRequest>,
1713    session: Session,
1714) -> Result<()> {
1715    let responder_id = session.user_id;
1716    let requester_id = UserId::from_proto(request.requester_id);
1717    let db = session.db().await;
1718    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1719        db.dismiss_contact_notification(responder_id, requester_id)
1720            .await?;
1721    } else {
1722        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1723
1724        db.respond_to_contact_request(responder_id, requester_id, accept)
1725            .await?;
1726        let requester_busy = db.is_user_busy(requester_id).await?;
1727        let responder_busy = db.is_user_busy(responder_id).await?;
1728
1729        let pool = session.connection_pool().await;
1730        // Update responder with new contact
1731        let mut update = proto::UpdateContacts::default();
1732        if accept {
1733            update
1734                .contacts
1735                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1736        }
1737        update
1738            .remove_incoming_requests
1739            .push(requester_id.to_proto());
1740        for connection_id in pool.user_connection_ids(responder_id) {
1741            session.peer.send(connection_id, update.clone())?;
1742        }
1743
1744        // Update requester with new contact
1745        let mut update = proto::UpdateContacts::default();
1746        if accept {
1747            update
1748                .contacts
1749                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1750        }
1751        update
1752            .remove_outgoing_requests
1753            .push(responder_id.to_proto());
1754        for connection_id in pool.user_connection_ids(requester_id) {
1755            session.peer.send(connection_id, update.clone())?;
1756        }
1757    }
1758
1759    response.send(proto::Ack {})?;
1760    Ok(())
1761}
1762
1763async fn remove_contact(
1764    request: proto::RemoveContact,
1765    response: Response<proto::RemoveContact>,
1766    session: Session,
1767) -> Result<()> {
1768    let requester_id = session.user_id;
1769    let responder_id = UserId::from_proto(request.user_id);
1770    let db = session.db().await;
1771    db.remove_contact(requester_id, responder_id).await?;
1772
1773    let pool = session.connection_pool().await;
1774    // Update outgoing contact requests of requester
1775    let mut update = proto::UpdateContacts::default();
1776    update
1777        .remove_outgoing_requests
1778        .push(responder_id.to_proto());
1779    for connection_id in pool.user_connection_ids(requester_id) {
1780        session.peer.send(connection_id, update.clone())?;
1781    }
1782
1783    // Update incoming contact requests of responder
1784    let mut update = proto::UpdateContacts::default();
1785    update
1786        .remove_incoming_requests
1787        .push(requester_id.to_proto());
1788    for connection_id in pool.user_connection_ids(responder_id) {
1789        session.peer.send(connection_id, update.clone())?;
1790    }
1791
1792    response.send(proto::Ack {})?;
1793    Ok(())
1794}
1795
1796async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1797    let project_id = ProjectId::from_proto(request.project_id);
1798    let project_connection_ids = session
1799        .db()
1800        .await
1801        .project_connection_ids(project_id, session.connection_id)
1802        .await?;
1803    broadcast(
1804        session.connection_id,
1805        project_connection_ids.iter().copied(),
1806        |connection_id| {
1807            session
1808                .peer
1809                .forward_send(session.connection_id, connection_id, request.clone())
1810        },
1811    );
1812    Ok(())
1813}
1814
1815async fn get_private_user_info(
1816    _request: proto::GetPrivateUserInfo,
1817    response: Response<proto::GetPrivateUserInfo>,
1818    session: Session,
1819) -> Result<()> {
1820    let metrics_id = session
1821        .db()
1822        .await
1823        .get_user_metrics_id(session.user_id)
1824        .await?;
1825    let user = session
1826        .db()
1827        .await
1828        .get_user_by_id(session.user_id)
1829        .await?
1830        .ok_or_else(|| anyhow!("user not found"))?;
1831    response.send(proto::GetPrivateUserInfoResponse {
1832        metrics_id,
1833        staff: user.admin,
1834    })?;
1835    Ok(())
1836}
1837
1838fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1839    match message {
1840        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1841        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1842        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1843        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1844        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1845            code: frame.code.into(),
1846            reason: frame.reason,
1847        })),
1848    }
1849}
1850
1851fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1852    match message {
1853        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1854        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1855        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1856        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1857        AxumMessage::Close(frame) => {
1858            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1859                code: frame.code.into(),
1860                reason: frame.reason,
1861            }))
1862        }
1863    }
1864}
1865
1866fn build_initial_contacts_update(
1867    contacts: Vec<db::Contact>,
1868    pool: &ConnectionPool,
1869) -> proto::UpdateContacts {
1870    let mut update = proto::UpdateContacts::default();
1871
1872    for contact in contacts {
1873        match contact {
1874            db::Contact::Accepted {
1875                user_id,
1876                should_notify,
1877                busy,
1878            } => {
1879                update
1880                    .contacts
1881                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1882            }
1883            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1884            db::Contact::Incoming {
1885                user_id,
1886                should_notify,
1887            } => update
1888                .incoming_requests
1889                .push(proto::IncomingContactRequest {
1890                    requester_id: user_id.to_proto(),
1891                    should_notify,
1892                }),
1893        }
1894    }
1895
1896    update
1897}
1898
1899fn contact_for_user(
1900    user_id: UserId,
1901    should_notify: bool,
1902    busy: bool,
1903    pool: &ConnectionPool,
1904) -> proto::Contact {
1905    proto::Contact {
1906        user_id: user_id.to_proto(),
1907        online: pool.is_user_online(user_id),
1908        busy,
1909        should_notify,
1910    }
1911}
1912
1913fn room_updated(room: &proto::Room, peer: &Peer) {
1914    for participant in &room.participants {
1915        peer.send(
1916            ConnectionId(participant.peer_id),
1917            proto::RoomUpdated {
1918                room: Some(room.clone()),
1919            },
1920        )
1921        .trace_err();
1922    }
1923}
1924
1925async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1926    let db = session.db().await;
1927    let contacts = db.get_contacts(user_id).await?;
1928    let busy = db.is_user_busy(user_id).await?;
1929
1930    let pool = session.connection_pool().await;
1931    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1932    for contact in contacts {
1933        if let db::Contact::Accepted {
1934            user_id: contact_user_id,
1935            ..
1936        } = contact
1937        {
1938            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1939                session
1940                    .peer
1941                    .send(
1942                        contact_conn_id,
1943                        proto::UpdateContacts {
1944                            contacts: vec![updated_contact.clone()],
1945                            remove_contacts: Default::default(),
1946                            incoming_requests: Default::default(),
1947                            remove_incoming_requests: Default::default(),
1948                            outgoing_requests: Default::default(),
1949                            remove_outgoing_requests: Default::default(),
1950                        },
1951                    )
1952                    .trace_err();
1953            }
1954        }
1955    }
1956    Ok(())
1957}
1958
1959async fn leave_room_for_session(session: &Session) -> Result<()> {
1960    let mut contacts_to_update = HashSet::default();
1961
1962    let room_id;
1963    let canceled_calls_to_user_ids;
1964    let live_kit_room;
1965    let delete_live_kit_room;
1966    {
1967        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1968        contacts_to_update.insert(session.user_id);
1969
1970        for project in left_room.left_projects.values() {
1971            project_left(project, session);
1972        }
1973
1974        room_updated(&left_room.room, &session.peer);
1975        room_id = RoomId::from_proto(left_room.room.id);
1976        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1977        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1978        delete_live_kit_room = left_room.room.participants.is_empty();
1979    }
1980
1981    {
1982        let pool = session.connection_pool().await;
1983        for canceled_user_id in canceled_calls_to_user_ids {
1984            for connection_id in pool.user_connection_ids(canceled_user_id) {
1985                session
1986                    .peer
1987                    .send(
1988                        connection_id,
1989                        proto::CallCanceled {
1990                            room_id: room_id.to_proto(),
1991                        },
1992                    )
1993                    .trace_err();
1994            }
1995            contacts_to_update.insert(canceled_user_id);
1996        }
1997    }
1998
1999    for contact_user_id in contacts_to_update {
2000        update_user_contacts(contact_user_id, &session).await?;
2001    }
2002
2003    if let Some(live_kit) = session.live_kit_client.as_ref() {
2004        live_kit
2005            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
2006            .await
2007            .trace_err();
2008
2009        if delete_live_kit_room {
2010            live_kit.delete_room(live_kit_room).await.trace_err();
2011        }
2012    }
2013
2014    Ok(())
2015}
2016
2017fn project_left(project: &db::LeftProject, session: &Session) {
2018    for connection_id in &project.connection_ids {
2019        if project.host_user_id == session.user_id {
2020            session
2021                .peer
2022                .send(
2023                    *connection_id,
2024                    proto::UnshareProject {
2025                        project_id: project.id.to_proto(),
2026                    },
2027                )
2028                .trace_err();
2029        } else {
2030            session
2031                .peer
2032                .send(
2033                    *connection_id,
2034                    proto::RemoveProjectCollaborator {
2035                        project_id: project.id.to_proto(),
2036                        peer_id: session.connection_id.0,
2037                    },
2038                )
2039                .trace_err();
2040        }
2041    }
2042
2043    session
2044        .peer
2045        .send(
2046            session.connection_id,
2047            proto::UnshareProject {
2048                project_id: project.id.to_proto(),
2049            },
2050        )
2051        .trace_err();
2052}
2053
2054pub trait ResultExt {
2055    type Ok;
2056
2057    fn trace_err(self) -> Option<Self::Ok>;
2058}
2059
2060impl<T, E> ResultExt for Result<T, E>
2061where
2062    E: std::fmt::Debug,
2063{
2064    type Ok = T;
2065
2066    fn trace_err(self) -> Option<T> {
2067        match self {
2068            Ok(value) => Some(value),
2069            Err(error) => {
2070                tracing::error!("{:?}", error);
2071                None
2072            }
2073        }
2074    }
2075}