rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::{watch, Mutex, MutexGuard};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
/// Grace period, equal to `rpc::RECEIVE_TIMEOUT`, used in two places.
///
/// - `sign_out` waits this long before removing a disconnected user from their room, unless
///   the server is torn down first.
/// - `Server::start` waits this long before refreshing outdated rooms.
pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
  61
lazy_static! {
    // Prometheus gauge of open connections. `handle_metrics` sets it to the number of
    // non-admin connections. The `unwrap` panics on first access if registration fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge of shared projects. `handle_metrics` sets it from
    // `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  71
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the boxed envelope plus the session and returns a `'static` boxed future that
/// performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
/// Reply handle given to request handlers for answering exactly one request.
///
/// `responded` is set by [`Response::send`]. The dispatcher built in
/// `Server::add_request_handler` uses it to detect handlers that returned `Ok` without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
/// Per-connection state handed to every message handler.
///
/// `Server::handle_connection` builds one per connection and clones it for each incoming
/// message. The `Arc` fields are shared between those clones.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // One mutex per connection, created in `handle_connection`. Database access from this
    // session's handlers is therefore serialized. Lock it through `Session::db`.
    db: Arc<Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Same pool as `Server::connection_pool`. Lock it through `Session::connection_pool`.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    // Copied from `AppState::live_kit_client`. Handlers skip LiveKit work when it is `None`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
  98
impl Session {
    /// Locks this session's database handle.
    ///
    /// In test builds the task yields before and after acquiring the lock, presumably to
    /// surface interleavings between concurrently running handlers.
    async fn db(&self) -> MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and wraps the guard in a `!Send` [`ConnectionPoolGuard`].
    ///
    /// Test builds yield around acquiring the lock, as in `Session::db`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
/// Wrapper around the shared [`Database`] that each session keeps behind its own mutex.
///
/// It dereferences to `Database`.
struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
/// The RPC server.
///
/// It owns the peer, the pool of connected clients and the table of registered message
/// handlers.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Registered handlers, keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // `handle_connection` and `sign_out` subscribe to this. `Server::teardown` sends on it to
    // stop them.
    teardown: watch::Sender<()>,
}
 149
/// Lock guard over the [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result it cannot be held
/// across an `.await` inside a future that must be `Send`. In test builds, dropping the guard
/// runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 154
/// Serializable view of the server's state.
///
/// It holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 161
 162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 163where
 164    S: Serializer,
 165    T: Deref<Target = U>,
 166    U: Serialize,
 167{
 168    Serialize::serialize(value.deref(), serializer)
 169}
 170
 171impl Server {
 172    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 173        let mut server = Self {
 174            peer: Peer::new(),
 175            app_state,
 176            executor,
 177            connection_pool: Default::default(),
 178            handlers: Default::default(),
 179            teardown: watch::channel(()).0,
 180        };
 181
 182        server
 183            .add_request_handler(ping)
 184            .add_request_handler(create_room)
 185            .add_request_handler(join_room)
 186            .add_message_handler(leave_room)
 187            .add_request_handler(call)
 188            .add_request_handler(cancel_call)
 189            .add_message_handler(decline_call)
 190            .add_request_handler(update_participant_location)
 191            .add_request_handler(share_project)
 192            .add_message_handler(unshare_project)
 193            .add_request_handler(join_project)
 194            .add_message_handler(leave_project)
 195            .add_request_handler(update_project)
 196            .add_request_handler(update_worktree)
 197            .add_message_handler(start_language_server)
 198            .add_message_handler(update_language_server)
 199            .add_message_handler(update_diagnostic_summary)
 200            .add_request_handler(forward_project_request::<proto::GetHover>)
 201            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 202            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 203            .add_request_handler(forward_project_request::<proto::GetReferences>)
 204            .add_request_handler(forward_project_request::<proto::SearchProject>)
 205            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 206            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 209            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 210            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 211            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 212            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 213            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 214            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 215            .add_request_handler(forward_project_request::<proto::PerformRename>)
 216            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 217            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 218            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 221            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 222            .add_message_handler(create_buffer_for_peer)
 223            .add_request_handler(update_buffer)
 224            .add_message_handler(update_buffer_file)
 225            .add_message_handler(buffer_reloaded)
 226            .add_message_handler(buffer_saved)
 227            .add_request_handler(save_buffer)
 228            .add_request_handler(get_users)
 229            .add_request_handler(fuzzy_search_users)
 230            .add_request_handler(request_contact)
 231            .add_request_handler(remove_contact)
 232            .add_request_handler(respond_to_contact_request)
 233            .add_request_handler(follow)
 234            .add_message_handler(unfollow)
 235            .add_message_handler(update_followers)
 236            .add_message_handler(update_diff_base)
 237            .add_request_handler(get_private_user_info);
 238
 239        Arc::new(server)
 240    }
 241
 242    pub async fn start(&self) -> Result<()> {
 243        self.app_state.db.delete_stale_projects().await?;
 244        let db = self.app_state.db.clone();
 245        let peer = self.peer.clone();
 246        let timeout = self.executor.sleep(RECONNECT_TIMEOUT);
 247        let pool = self.connection_pool.clone();
 248        let live_kit_client = self.app_state.live_kit_client.clone();
 249        self.executor.spawn_detached(async move {
 250            timeout.await;
 251            if let Some(room_ids) = db.outdated_room_ids().await.trace_err() {
 252                for room_id in room_ids {
 253                    let mut contacts_to_update = HashSet::default();
 254                    let mut canceled_calls_to_user_ids = Vec::new();
 255                    let mut live_kit_room = String::new();
 256                    let mut delete_live_kit_room = false;
 257
 258                    if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 259                        room_updated(&refreshed_room.room, &peer);
 260                        contacts_to_update
 261                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 262                        contacts_to_update
 263                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 264                        canceled_calls_to_user_ids =
 265                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 266                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 267                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
 268                    }
 269
 270                    {
 271                        let pool = pool.lock().await;
 272                        for canceled_user_id in canceled_calls_to_user_ids {
 273                            for connection_id in pool.user_connection_ids(canceled_user_id) {
 274                                peer.send(
 275                                    connection_id,
 276                                    proto::CallCanceled {
 277                                        room_id: room_id.to_proto(),
 278                                    },
 279                                )
 280                                .trace_err();
 281                            }
 282                        }
 283                    }
 284
 285                    for user_id in contacts_to_update {
 286                        let busy = db.is_user_busy(user_id).await.trace_err();
 287                        let contacts = db.get_contacts(user_id).await.trace_err();
 288                        if let Some((busy, contacts)) = busy.zip(contacts) {
 289                            let pool = pool.lock().await;
 290                            let updated_contact = contact_for_user(user_id, false, busy, &pool);
 291                            for contact in contacts {
 292                                if let db::Contact::Accepted {
 293                                    user_id: contact_user_id,
 294                                    ..
 295                                } = contact
 296                                {
 297                                    for contact_conn_id in pool.user_connection_ids(contact_user_id)
 298                                    {
 299                                        peer.send(
 300                                            contact_conn_id,
 301                                            proto::UpdateContacts {
 302                                                contacts: vec![updated_contact.clone()],
 303                                                remove_contacts: Default::default(),
 304                                                incoming_requests: Default::default(),
 305                                                remove_incoming_requests: Default::default(),
 306                                                outgoing_requests: Default::default(),
 307                                                remove_outgoing_requests: Default::default(),
 308                                            },
 309                                        )
 310                                        .trace_err();
 311                                    }
 312                                }
 313                            }
 314                        }
 315                    }
 316
 317                    if let Some(live_kit) = live_kit_client.as_ref() {
 318                        if delete_live_kit_room {
 319                            live_kit.delete_room(live_kit_room).await.trace_err();
 320                        }
 321                    }
 322                }
 323            }
 324        });
 325        Ok(())
 326    }
 327
 328    pub fn teardown(&self) {
 329        self.peer.reset();
 330        let _ = self.teardown.send(());
 331    }
 332
 333    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 334    where
 335        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 336        Fut: 'static + Send + Future<Output = Result<()>>,
 337        M: EnvelopedMessage,
 338    {
 339        let prev_handler = self.handlers.insert(
 340            TypeId::of::<M>(),
 341            Box::new(move |envelope, session| {
 342                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 343                let span = info_span!(
 344                    "handle message",
 345                    payload_type = envelope.payload_type_name()
 346                );
 347                span.in_scope(|| {
 348                    tracing::info!(
 349                        payload_type = envelope.payload_type_name(),
 350                        "message received"
 351                    );
 352                });
 353                let future = (handler)(*envelope, session);
 354                async move {
 355                    if let Err(error) = future.await {
 356                        tracing::error!(%error, "error handling message");
 357                    }
 358                }
 359                .instrument(span)
 360                .boxed()
 361            }),
 362        );
 363        if prev_handler.is_some() {
 364            panic!("registered a handler for the same message twice");
 365        }
 366        self
 367    }
 368
 369    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 370    where
 371        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 372        Fut: 'static + Send + Future<Output = Result<()>>,
 373        M: EnvelopedMessage,
 374    {
 375        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 376        self
 377    }
 378
 379    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 380    where
 381        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 382        Fut: Send + Future<Output = Result<()>>,
 383        M: RequestMessage,
 384    {
 385        let handler = Arc::new(handler);
 386        self.add_handler(move |envelope, session| {
 387            let receipt = envelope.receipt();
 388            let handler = handler.clone();
 389            async move {
 390                let peer = session.peer.clone();
 391                let responded = Arc::new(AtomicBool::default());
 392                let response = Response {
 393                    peer: peer.clone(),
 394                    responded: responded.clone(),
 395                    receipt,
 396                };
 397                match (handler)(envelope.payload, response, session).await {
 398                    Ok(()) => {
 399                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 400                            Ok(())
 401                        } else {
 402                            Err(anyhow!("handler did not send a response"))?
 403                        }
 404                    }
 405                    Err(error) => {
 406                        peer.respond_with_error(
 407                            receipt,
 408                            proto::Error {
 409                                message: error.to_string(),
 410                            },
 411                        )?;
 412                        Err(error)
 413                    }
 414                }
 415            }
 416        })
 417    }
 418
 419    pub fn handle_connection(
 420        self: &Arc<Self>,
 421        connection: Connection,
 422        address: String,
 423        user: User,
 424        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 425        executor: Executor,
 426    ) -> impl Future<Output = Result<()>> {
 427        let this = self.clone();
 428        let user_id = user.id;
 429        let login = user.github_login;
 430        let span = info_span!("handle connection", %user_id, %login, %address);
 431        let mut teardown = self.teardown.subscribe();
 432        async move {
 433            let (connection_id, handle_io, mut incoming_rx) = this
 434                .peer
 435                .add_connection(connection, {
 436                    let executor = executor.clone();
 437                    move |duration| executor.sleep(duration)
 438                });
 439
 440            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 441            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 442            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 443
 444            if let Some(send_connection_id) = send_connection_id.take() {
 445                let _ = send_connection_id.send(connection_id);
 446            }
 447
 448            if !user.connected_once {
 449                this.peer.send(connection_id, proto::ShowContacts {})?;
 450                this.app_state.db.set_user_connected_once(user_id, true).await?;
 451            }
 452
 453            let (contacts, invite_code) = future::try_join(
 454                this.app_state.db.get_contacts(user_id),
 455                this.app_state.db.get_invite_code_for_user(user_id)
 456            ).await?;
 457
 458            {
 459                let mut pool = this.connection_pool.lock().await;
 460                pool.add_connection(connection_id, user_id, user.admin);
 461                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 462
 463                if let Some((code, count)) = invite_code {
 464                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 465                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 466                        count: count as u32,
 467                    })?;
 468                }
 469            }
 470
 471            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 472                this.peer.send(connection_id, incoming_call)?;
 473            }
 474
 475            let session = Session {
 476                user_id,
 477                connection_id,
 478                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
 479                peer: this.peer.clone(),
 480                connection_pool: this.connection_pool.clone(),
 481                live_kit_client: this.app_state.live_kit_client.clone()
 482            };
 483            update_user_contacts(user_id, &session).await?;
 484
 485            let handle_io = handle_io.fuse();
 486            futures::pin_mut!(handle_io);
 487
 488            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 489            // This prevents deadlocks when e.g., client A performs a request to client B and
 490            // client B performs a request to client A. If both clients stop processing further
 491            // messages until their respective request completes, they won't have a chance to
 492            // respond to the other client's request and cause a deadlock.
 493            //
 494            // This arrangement ensures we will attempt to process earlier messages first, but fall
 495            // back to processing messages arrived later in the spirit of making progress.
 496            let mut foreground_message_handlers = FuturesUnordered::new();
 497            loop {
 498                let next_message = incoming_rx.next().fuse();
 499                futures::pin_mut!(next_message);
 500                futures::select_biased! {
 501                    _ = teardown.changed().fuse() => return Ok(()),
 502                    result = handle_io => {
 503                        if let Err(error) = result {
 504                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 505                        }
 506                        break;
 507                    }
 508                    _ = foreground_message_handlers.next() => {}
 509                    message = next_message => {
 510                        if let Some(message) = message {
 511                            let type_name = message.payload_type_name();
 512                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 513                            let span_enter = span.enter();
 514                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 515                                let is_background = message.is_background();
 516                                let handle_message = (handler)(message, session.clone());
 517                                drop(span_enter);
 518
 519                                let handle_message = handle_message.instrument(span);
 520                                if is_background {
 521                                    executor.spawn_detached(handle_message);
 522                                } else {
 523                                    foreground_message_handlers.push(handle_message);
 524                                }
 525                            } else {
 526                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 527                            }
 528                        } else {
 529                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 530                            break;
 531                        }
 532                    }
 533                }
 534            }
 535
 536            drop(foreground_message_handlers);
 537            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 538            if let Err(error) = sign_out(session, teardown, executor).await {
 539                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 540            }
 541
 542            Ok(())
 543        }.instrument(span)
 544    }
 545
 546    pub async fn invite_code_redeemed(
 547        self: &Arc<Self>,
 548        inviter_id: UserId,
 549        invitee_id: UserId,
 550    ) -> Result<()> {
 551        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 552            if let Some(code) = &user.invite_code {
 553                let pool = self.connection_pool.lock().await;
 554                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 555                for connection_id in pool.user_connection_ids(inviter_id) {
 556                    self.peer.send(
 557                        connection_id,
 558                        proto::UpdateContacts {
 559                            contacts: vec![invitee_contact.clone()],
 560                            ..Default::default()
 561                        },
 562                    )?;
 563                    self.peer.send(
 564                        connection_id,
 565                        proto::UpdateInviteInfo {
 566                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 567                            count: user.invite_count as u32,
 568                        },
 569                    )?;
 570                }
 571            }
 572        }
 573        Ok(())
 574    }
 575
 576    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 577        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 578            if let Some(invite_code) = &user.invite_code {
 579                let pool = self.connection_pool.lock().await;
 580                for connection_id in pool.user_connection_ids(user_id) {
 581                    self.peer.send(
 582                        connection_id,
 583                        proto::UpdateInviteInfo {
 584                            url: format!(
 585                                "{}{}",
 586                                self.app_state.config.invite_link_prefix, invite_code
 587                            ),
 588                            count: user.invite_count as u32,
 589                        },
 590                    )?;
 591                }
 592            }
 593        }
 594        Ok(())
 595    }
 596
 597    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 598        ServerSnapshot {
 599            connection_pool: ConnectionPoolGuard {
 600                guard: self.connection_pool.lock().await,
 601                _not_send: PhantomData,
 602            },
 603            peer: &self.peer,
 604        }
 605    }
 606}
 607
 608impl<'a> Deref for ConnectionPoolGuard<'a> {
 609    type Target = ConnectionPool;
 610
 611    fn deref(&self) -> &Self::Target {
 612        &*self.guard
 613    }
 614}
 615
 616impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 617    fn deref_mut(&mut self) -> &mut Self::Target {
 618        &mut *self.guard
 619    }
 620}
 621
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants while the lock is still held (the
    /// inner mutex guard is released only after this runs). Release builds do nothing extra.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 628
 629fn broadcast<F>(
 630    sender_id: ConnectionId,
 631    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 632    mut f: F,
 633) where
 634    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 635{
 636    for receiver_id in receiver_ids {
 637        if receiver_id != sender_id {
 638            f(receiver_id).trace_err();
 639        }
 640    }
 641}
 642
lazy_static! {
    // Name of the request header that carries the client's RPC protocol version
    // (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 646
/// Typed `x-zed-protocol-version` header holding the client's RPC protocol version.
///
/// `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
 648
 649impl Header for ProtocolVersion {
 650    fn name() -> &'static HeaderName {
 651        &ZED_PROTOCOL_VERSION
 652    }
 653
 654    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 655    where
 656        Self: Sized,
 657        I: Iterator<Item = &'i axum::http::HeaderValue>,
 658    {
 659        let version = values
 660            .next()
 661            .ok_or_else(axum::headers::Error::invalid)?
 662            .to_str()
 663            .map_err(|_| axum::headers::Error::invalid())?
 664            .parse()
 665            .map_err(|_| axum::headers::Error::invalid())?;
 666        Ok(Self(version))
 667    }
 668
 669    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 670        values.extend([self.0.to_string().parse().unwrap()]);
 671    }
 672}
 673
/// Builds the router exposing `/rpc` (websocket upgrade) and `/metrics`.
///
/// `Router::layer` only wraps routes added before it. The `ServiceBuilder` layer therefore
/// applies to `/rpc` alone: it attaches the `AppState` extension and runs
/// `auth::validate_header`. The final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 685
 686pub async fn handle_websocket_request(
 687    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 688    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 689    Extension(server): Extension<Arc<Server>>,
 690    Extension(user): Extension<User>,
 691    ws: WebSocketUpgrade,
 692) -> axum::response::Response {
 693    if protocol_version != rpc::PROTOCOL_VERSION {
 694        return (
 695            StatusCode::UPGRADE_REQUIRED,
 696            "client must be upgraded".to_string(),
 697        )
 698            .into_response();
 699    }
 700    let socket_address = socket_address.to_string();
 701    ws.on_upgrade(move |socket| {
 702        use util::ResultExt;
 703        let socket = socket
 704            .map_ok(to_tungstenite_message)
 705            .err_into()
 706            .with(|message| async move { Ok(to_axum_message(message)) });
 707        let connection = Connection::new(Box::pin(socket));
 708        async move {
 709            server
 710                .handle_connection(connection, socket_address, user, None, Executor::Production)
 711                .await
 712                .log_err();
 713        }
 714    })
 715}
 716
 717pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 718    let connections = server
 719        .connection_pool
 720        .lock()
 721        .await
 722        .connections()
 723        .filter(|connection| !connection.admin)
 724        .count();
 725
 726    METRIC_CONNECTIONS.set(connections as _);
 727
 728    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 729    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 730
 731    let encoder = prometheus::TextEncoder::new();
 732    let metric_families = prometheus::gather();
 733    let encoded_metrics = encoder
 734        .encode_to_string(&metric_families)
 735        .map_err(|err| anyhow!("{}", err))?;
 736    Ok(encoded_metrics)
 737}
 738
/// Cleans up after a closed connection.
///
/// First it disconnects the peer, removes the connection from the pool and calls
/// `connection_lost`, passing each project the connection left to `project_left`.
///
/// It then waits [`RECONNECT_TIMEOUT`]. If the teardown channel fires first, nothing more
/// happens. Otherwise it:
/// 1. leaves the session's room;
/// 2. declines any call to the user if they have no other online connection;
/// 3. refreshes the user's contacts.
///
/// # Errors
///
/// Returns an error if `remove_connection` or the final `update_user_contacts` fails. Other
/// failures are passed to `trace_err` and ignored.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Another connection of the same user may still be online; only decline calls
            // when none is.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 784
 785async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 786    response.send(proto::Ack {})?;
 787    Ok(())
 788}
 789
 790async fn create_room(
 791    _request: proto::CreateRoom,
 792    response: Response<proto::CreateRoom>,
 793    session: Session,
 794) -> Result<()> {
 795    let live_kit_room = nanoid::nanoid!(30);
 796    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 797        if let Some(_) = live_kit
 798            .create_room(live_kit_room.clone())
 799            .await
 800            .trace_err()
 801        {
 802            if let Some(token) = live_kit
 803                .room_token(&live_kit_room, &session.connection_id.to_string())
 804                .trace_err()
 805            {
 806                Some(proto::LiveKitConnectionInfo {
 807                    server_url: live_kit.url().into(),
 808                    token,
 809                })
 810            } else {
 811                None
 812            }
 813        } else {
 814            None
 815        }
 816    } else {
 817        None
 818    };
 819
 820    {
 821        let room = session
 822            .db()
 823            .await
 824            .create_room(session.user_id, session.connection_id, &live_kit_room)
 825            .await?;
 826
 827        response.send(proto::CreateRoomResponse {
 828            room: Some(room.clone()),
 829            live_kit_connection_info,
 830        })?;
 831    }
 832
 833    update_user_contacts(session.user_id, &session).await?;
 834    Ok(())
 835}
 836
 837async fn join_room(
 838    request: proto::JoinRoom,
 839    response: Response<proto::JoinRoom>,
 840    session: Session,
 841) -> Result<()> {
 842    let room_id = RoomId::from_proto(request.id);
 843    let room = {
 844        let room = session
 845            .db()
 846            .await
 847            .join_room(room_id, session.user_id, session.connection_id)
 848            .await?;
 849        room_updated(&room, &session.peer);
 850        room.clone()
 851    };
 852
 853    for connection_id in session
 854        .connection_pool()
 855        .await
 856        .user_connection_ids(session.user_id)
 857    {
 858        session
 859            .peer
 860            .send(
 861                connection_id,
 862                proto::CallCanceled {
 863                    room_id: room_id.to_proto(),
 864                },
 865            )
 866            .trace_err();
 867    }
 868
 869    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 870        if let Some(token) = live_kit
 871            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 872            .trace_err()
 873        {
 874            Some(proto::LiveKitConnectionInfo {
 875                server_url: live_kit.url().into(),
 876                token,
 877            })
 878        } else {
 879            None
 880        }
 881    } else {
 882        None
 883    };
 884
 885    response.send(proto::JoinRoomResponse {
 886        room: Some(room),
 887        live_kit_connection_info,
 888    })?;
 889
 890    update_user_contacts(session.user_id, &session).await?;
 891    Ok(())
 892}
 893
 894async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 895    leave_room_for_session(&session).await
 896}
 897
 898async fn call(
 899    request: proto::Call,
 900    response: Response<proto::Call>,
 901    session: Session,
 902) -> Result<()> {
 903    let room_id = RoomId::from_proto(request.room_id);
 904    let calling_user_id = session.user_id;
 905    let calling_connection_id = session.connection_id;
 906    let called_user_id = UserId::from_proto(request.called_user_id);
 907    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 908    if !session
 909        .db()
 910        .await
 911        .has_contact(calling_user_id, called_user_id)
 912        .await?
 913    {
 914        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 915    }
 916
 917    let incoming_call = {
 918        let (room, incoming_call) = &mut *session
 919            .db()
 920            .await
 921            .call(
 922                room_id,
 923                calling_user_id,
 924                calling_connection_id,
 925                called_user_id,
 926                initial_project_id,
 927            )
 928            .await?;
 929        room_updated(&room, &session.peer);
 930        mem::take(incoming_call)
 931    };
 932    update_user_contacts(called_user_id, &session).await?;
 933
 934    let mut calls = session
 935        .connection_pool()
 936        .await
 937        .user_connection_ids(called_user_id)
 938        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 939        .collect::<FuturesUnordered<_>>();
 940
 941    while let Some(call_response) = calls.next().await {
 942        match call_response.as_ref() {
 943            Ok(_) => {
 944                response.send(proto::Ack {})?;
 945                return Ok(());
 946            }
 947            Err(_) => {
 948                call_response.trace_err();
 949            }
 950        }
 951    }
 952
 953    {
 954        let room = session
 955            .db()
 956            .await
 957            .call_failed(room_id, called_user_id)
 958            .await?;
 959        room_updated(&room, &session.peer);
 960    }
 961    update_user_contacts(called_user_id, &session).await?;
 962
 963    Err(anyhow!("failed to ring user"))?
 964}
 965
 966async fn cancel_call(
 967    request: proto::CancelCall,
 968    response: Response<proto::CancelCall>,
 969    session: Session,
 970) -> Result<()> {
 971    let called_user_id = UserId::from_proto(request.called_user_id);
 972    let room_id = RoomId::from_proto(request.room_id);
 973    {
 974        let room = session
 975            .db()
 976            .await
 977            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 978            .await?;
 979        room_updated(&room, &session.peer);
 980    }
 981
 982    for connection_id in session
 983        .connection_pool()
 984        .await
 985        .user_connection_ids(called_user_id)
 986    {
 987        session
 988            .peer
 989            .send(
 990                connection_id,
 991                proto::CallCanceled {
 992                    room_id: room_id.to_proto(),
 993                },
 994            )
 995            .trace_err();
 996    }
 997    response.send(proto::Ack {})?;
 998
 999    update_user_contacts(called_user_id, &session).await?;
1000    Ok(())
1001}
1002
1003async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1004    let room_id = RoomId::from_proto(message.room_id);
1005    {
1006        let room = session
1007            .db()
1008            .await
1009            .decline_call(Some(room_id), session.user_id)
1010            .await?;
1011        room_updated(&room, &session.peer);
1012    }
1013
1014    for connection_id in session
1015        .connection_pool()
1016        .await
1017        .user_connection_ids(session.user_id)
1018    {
1019        session
1020            .peer
1021            .send(
1022                connection_id,
1023                proto::CallCanceled {
1024                    room_id: room_id.to_proto(),
1025                },
1026            )
1027            .trace_err();
1028    }
1029    update_user_contacts(session.user_id, &session).await?;
1030    Ok(())
1031}
1032
1033async fn update_participant_location(
1034    request: proto::UpdateParticipantLocation,
1035    response: Response<proto::UpdateParticipantLocation>,
1036    session: Session,
1037) -> Result<()> {
1038    let room_id = RoomId::from_proto(request.room_id);
1039    let location = request
1040        .location
1041        .ok_or_else(|| anyhow!("invalid location"))?;
1042    let room = session
1043        .db()
1044        .await
1045        .update_room_participant_location(room_id, session.connection_id, location)
1046        .await?;
1047    room_updated(&room, &session.peer);
1048    response.send(proto::Ack {})?;
1049    Ok(())
1050}
1051
1052async fn share_project(
1053    request: proto::ShareProject,
1054    response: Response<proto::ShareProject>,
1055    session: Session,
1056) -> Result<()> {
1057    let (project_id, room) = &*session
1058        .db()
1059        .await
1060        .share_project(
1061            RoomId::from_proto(request.room_id),
1062            session.connection_id,
1063            &request.worktrees,
1064        )
1065        .await?;
1066    response.send(proto::ShareProjectResponse {
1067        project_id: project_id.to_proto(),
1068    })?;
1069    room_updated(&room, &session.peer);
1070
1071    Ok(())
1072}
1073
1074async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1075    let project_id = ProjectId::from_proto(message.project_id);
1076
1077    let (room, guest_connection_ids) = &*session
1078        .db()
1079        .await
1080        .unshare_project(project_id, session.connection_id)
1081        .await?;
1082
1083    broadcast(
1084        session.connection_id,
1085        guest_connection_ids.iter().copied(),
1086        |conn_id| session.peer.send(conn_id, message.clone()),
1087    );
1088    room_updated(&room, &session.peer);
1089
1090    Ok(())
1091}
1092
1093async fn join_project(
1094    request: proto::JoinProject,
1095    response: Response<proto::JoinProject>,
1096    session: Session,
1097) -> Result<()> {
1098    let project_id = ProjectId::from_proto(request.project_id);
1099    let guest_user_id = session.user_id;
1100
1101    tracing::info!(%project_id, "join project");
1102
1103    let (project, replica_id) = &mut *session
1104        .db()
1105        .await
1106        .join_project(project_id, session.connection_id)
1107        .await?;
1108
1109    let collaborators = project
1110        .collaborators
1111        .iter()
1112        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1113        .map(|collaborator| proto::Collaborator {
1114            peer_id: collaborator.connection_id as u32,
1115            replica_id: collaborator.replica_id.0 as u32,
1116            user_id: collaborator.user_id.to_proto(),
1117        })
1118        .collect::<Vec<_>>();
1119    let worktrees = project
1120        .worktrees
1121        .iter()
1122        .map(|(id, worktree)| proto::WorktreeMetadata {
1123            id: *id,
1124            root_name: worktree.root_name.clone(),
1125            visible: worktree.visible,
1126            abs_path: worktree.abs_path.clone(),
1127        })
1128        .collect::<Vec<_>>();
1129
1130    for collaborator in &collaborators {
1131        session
1132            .peer
1133            .send(
1134                ConnectionId(collaborator.peer_id),
1135                proto::AddProjectCollaborator {
1136                    project_id: project_id.to_proto(),
1137                    collaborator: Some(proto::Collaborator {
1138                        peer_id: session.connection_id.0,
1139                        replica_id: replica_id.0 as u32,
1140                        user_id: guest_user_id.to_proto(),
1141                    }),
1142                },
1143            )
1144            .trace_err();
1145    }
1146
1147    // First, we send the metadata associated with each worktree.
1148    response.send(proto::JoinProjectResponse {
1149        worktrees: worktrees.clone(),
1150        replica_id: replica_id.0 as u32,
1151        collaborators: collaborators.clone(),
1152        language_servers: project.language_servers.clone(),
1153    })?;
1154
1155    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1156        #[cfg(any(test, feature = "test-support"))]
1157        const MAX_CHUNK_SIZE: usize = 2;
1158        #[cfg(not(any(test, feature = "test-support")))]
1159        const MAX_CHUNK_SIZE: usize = 256;
1160
1161        // Stream this worktree's entries.
1162        let message = proto::UpdateWorktree {
1163            project_id: project_id.to_proto(),
1164            worktree_id,
1165            abs_path: worktree.abs_path.clone(),
1166            root_name: worktree.root_name,
1167            updated_entries: worktree.entries,
1168            removed_entries: Default::default(),
1169            scan_id: worktree.scan_id,
1170            is_last_update: worktree.is_complete,
1171        };
1172        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1173            session.peer.send(session.connection_id, update.clone())?;
1174        }
1175
1176        // Stream this worktree's diagnostics.
1177        for summary in worktree.diagnostic_summaries {
1178            session.peer.send(
1179                session.connection_id,
1180                proto::UpdateDiagnosticSummary {
1181                    project_id: project_id.to_proto(),
1182                    worktree_id: worktree.id,
1183                    summary: Some(summary),
1184                },
1185            )?;
1186        }
1187    }
1188
1189    for language_server in &project.language_servers {
1190        session.peer.send(
1191            session.connection_id,
1192            proto::UpdateLanguageServer {
1193                project_id: project_id.to_proto(),
1194                language_server_id: language_server.id,
1195                variant: Some(
1196                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1197                        proto::LspDiskBasedDiagnosticsUpdated {},
1198                    ),
1199                ),
1200            },
1201        )?;
1202    }
1203
1204    Ok(())
1205}
1206
1207async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1208    let sender_id = session.connection_id;
1209    let project_id = ProjectId::from_proto(request.project_id);
1210
1211    let project = session
1212        .db()
1213        .await
1214        .leave_project(project_id, sender_id)
1215        .await?;
1216    tracing::info!(
1217        %project_id,
1218        host_user_id = %project.host_user_id,
1219        host_connection_id = %project.host_connection_id,
1220        "leave project"
1221    );
1222    project_left(&project, &session);
1223
1224    Ok(())
1225}
1226
1227async fn update_project(
1228    request: proto::UpdateProject,
1229    response: Response<proto::UpdateProject>,
1230    session: Session,
1231) -> Result<()> {
1232    let project_id = ProjectId::from_proto(request.project_id);
1233    let (room, guest_connection_ids) = &*session
1234        .db()
1235        .await
1236        .update_project(project_id, session.connection_id, &request.worktrees)
1237        .await?;
1238    broadcast(
1239        session.connection_id,
1240        guest_connection_ids.iter().copied(),
1241        |connection_id| {
1242            session
1243                .peer
1244                .forward_send(session.connection_id, connection_id, request.clone())
1245        },
1246    );
1247    room_updated(&room, &session.peer);
1248    response.send(proto::Ack {})?;
1249
1250    Ok(())
1251}
1252
1253async fn update_worktree(
1254    request: proto::UpdateWorktree,
1255    response: Response<proto::UpdateWorktree>,
1256    session: Session,
1257) -> Result<()> {
1258    let guest_connection_ids = session
1259        .db()
1260        .await
1261        .update_worktree(&request, session.connection_id)
1262        .await?;
1263
1264    broadcast(
1265        session.connection_id,
1266        guest_connection_ids.iter().copied(),
1267        |connection_id| {
1268            session
1269                .peer
1270                .forward_send(session.connection_id, connection_id, request.clone())
1271        },
1272    );
1273    response.send(proto::Ack {})?;
1274    Ok(())
1275}
1276
1277async fn update_diagnostic_summary(
1278    message: proto::UpdateDiagnosticSummary,
1279    session: Session,
1280) -> Result<()> {
1281    let guest_connection_ids = session
1282        .db()
1283        .await
1284        .update_diagnostic_summary(&message, session.connection_id)
1285        .await?;
1286
1287    broadcast(
1288        session.connection_id,
1289        guest_connection_ids.iter().copied(),
1290        |connection_id| {
1291            session
1292                .peer
1293                .forward_send(session.connection_id, connection_id, message.clone())
1294        },
1295    );
1296
1297    Ok(())
1298}
1299
1300async fn start_language_server(
1301    request: proto::StartLanguageServer,
1302    session: Session,
1303) -> Result<()> {
1304    let guest_connection_ids = session
1305        .db()
1306        .await
1307        .start_language_server(&request, session.connection_id)
1308        .await?;
1309
1310    broadcast(
1311        session.connection_id,
1312        guest_connection_ids.iter().copied(),
1313        |connection_id| {
1314            session
1315                .peer
1316                .forward_send(session.connection_id, connection_id, request.clone())
1317        },
1318    );
1319    Ok(())
1320}
1321
1322async fn update_language_server(
1323    request: proto::UpdateLanguageServer,
1324    session: Session,
1325) -> Result<()> {
1326    let project_id = ProjectId::from_proto(request.project_id);
1327    let project_connection_ids = session
1328        .db()
1329        .await
1330        .project_connection_ids(project_id, session.connection_id)
1331        .await?;
1332    broadcast(
1333        session.connection_id,
1334        project_connection_ids.iter().copied(),
1335        |connection_id| {
1336            session
1337                .peer
1338                .forward_send(session.connection_id, connection_id, request.clone())
1339        },
1340    );
1341    Ok(())
1342}
1343
1344async fn forward_project_request<T>(
1345    request: T,
1346    response: Response<T>,
1347    session: Session,
1348) -> Result<()>
1349where
1350    T: EntityMessage + RequestMessage,
1351{
1352    let project_id = ProjectId::from_proto(request.remote_entity_id());
1353    let host_connection_id = {
1354        let collaborators = session
1355            .db()
1356            .await
1357            .project_collaborators(project_id, session.connection_id)
1358            .await?;
1359        ConnectionId(
1360            collaborators
1361                .iter()
1362                .find(|collaborator| collaborator.is_host)
1363                .ok_or_else(|| anyhow!("host not found"))?
1364                .connection_id as u32,
1365        )
1366    };
1367
1368    let payload = session
1369        .peer
1370        .forward_request(session.connection_id, host_connection_id, request)
1371        .await?;
1372
1373    response.send(payload)?;
1374    Ok(())
1375}
1376
1377async fn save_buffer(
1378    request: proto::SaveBuffer,
1379    response: Response<proto::SaveBuffer>,
1380    session: Session,
1381) -> Result<()> {
1382    let project_id = ProjectId::from_proto(request.project_id);
1383    let host_connection_id = {
1384        let collaborators = session
1385            .db()
1386            .await
1387            .project_collaborators(project_id, session.connection_id)
1388            .await?;
1389        let host = collaborators
1390            .iter()
1391            .find(|collaborator| collaborator.is_host)
1392            .ok_or_else(|| anyhow!("host not found"))?;
1393        ConnectionId(host.connection_id as u32)
1394    };
1395    let response_payload = session
1396        .peer
1397        .forward_request(session.connection_id, host_connection_id, request.clone())
1398        .await?;
1399
1400    let mut collaborators = session
1401        .db()
1402        .await
1403        .project_collaborators(project_id, session.connection_id)
1404        .await?;
1405    collaborators
1406        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1407    let project_connection_ids = collaborators
1408        .iter()
1409        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1410    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1411        session
1412            .peer
1413            .forward_send(host_connection_id, conn_id, response_payload.clone())
1414    });
1415    response.send(response_payload)?;
1416    Ok(())
1417}
1418
1419async fn create_buffer_for_peer(
1420    request: proto::CreateBufferForPeer,
1421    session: Session,
1422) -> Result<()> {
1423    session.peer.forward_send(
1424        session.connection_id,
1425        ConnectionId(request.peer_id),
1426        request,
1427    )?;
1428    Ok(())
1429}
1430
1431async fn update_buffer(
1432    request: proto::UpdateBuffer,
1433    response: Response<proto::UpdateBuffer>,
1434    session: Session,
1435) -> Result<()> {
1436    let project_id = ProjectId::from_proto(request.project_id);
1437    let project_connection_ids = session
1438        .db()
1439        .await
1440        .project_connection_ids(project_id, session.connection_id)
1441        .await?;
1442
1443    broadcast(
1444        session.connection_id,
1445        project_connection_ids.iter().copied(),
1446        |connection_id| {
1447            session
1448                .peer
1449                .forward_send(session.connection_id, connection_id, request.clone())
1450        },
1451    );
1452    response.send(proto::Ack {})?;
1453    Ok(())
1454}
1455
1456async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1457    let project_id = ProjectId::from_proto(request.project_id);
1458    let project_connection_ids = session
1459        .db()
1460        .await
1461        .project_connection_ids(project_id, session.connection_id)
1462        .await?;
1463
1464    broadcast(
1465        session.connection_id,
1466        project_connection_ids.iter().copied(),
1467        |connection_id| {
1468            session
1469                .peer
1470                .forward_send(session.connection_id, connection_id, request.clone())
1471        },
1472    );
1473    Ok(())
1474}
1475
1476async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1477    let project_id = ProjectId::from_proto(request.project_id);
1478    let project_connection_ids = session
1479        .db()
1480        .await
1481        .project_connection_ids(project_id, session.connection_id)
1482        .await?;
1483    broadcast(
1484        session.connection_id,
1485        project_connection_ids.iter().copied(),
1486        |connection_id| {
1487            session
1488                .peer
1489                .forward_send(session.connection_id, connection_id, request.clone())
1490        },
1491    );
1492    Ok(())
1493}
1494
1495async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1496    let project_id = ProjectId::from_proto(request.project_id);
1497    let project_connection_ids = session
1498        .db()
1499        .await
1500        .project_connection_ids(project_id, session.connection_id)
1501        .await?;
1502    broadcast(
1503        session.connection_id,
1504        project_connection_ids.iter().copied(),
1505        |connection_id| {
1506            session
1507                .peer
1508                .forward_send(session.connection_id, connection_id, request.clone())
1509        },
1510    );
1511    Ok(())
1512}
1513
1514async fn follow(
1515    request: proto::Follow,
1516    response: Response<proto::Follow>,
1517    session: Session,
1518) -> Result<()> {
1519    let project_id = ProjectId::from_proto(request.project_id);
1520    let leader_id = ConnectionId(request.leader_id);
1521    let follower_id = session.connection_id;
1522    {
1523        let project_connection_ids = session
1524            .db()
1525            .await
1526            .project_connection_ids(project_id, session.connection_id)
1527            .await?;
1528
1529        if !project_connection_ids.contains(&leader_id) {
1530            Err(anyhow!("no such peer"))?;
1531        }
1532    }
1533
1534    let mut response_payload = session
1535        .peer
1536        .forward_request(session.connection_id, leader_id, request)
1537        .await?;
1538    response_payload
1539        .views
1540        .retain(|view| view.leader_id != Some(follower_id.0));
1541    response.send(response_payload)?;
1542    Ok(())
1543}
1544
1545async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1546    let project_id = ProjectId::from_proto(request.project_id);
1547    let leader_id = ConnectionId(request.leader_id);
1548    let project_connection_ids = session
1549        .db()
1550        .await
1551        .project_connection_ids(project_id, session.connection_id)
1552        .await?;
1553    if !project_connection_ids.contains(&leader_id) {
1554        Err(anyhow!("no such peer"))?;
1555    }
1556    session
1557        .peer
1558        .forward_send(session.connection_id, leader_id, request)?;
1559    Ok(())
1560}
1561
1562async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1563    let project_id = ProjectId::from_proto(request.project_id);
1564    let project_connection_ids = session
1565        .db
1566        .lock()
1567        .await
1568        .project_connection_ids(project_id, session.connection_id)
1569        .await?;
1570
1571    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1572        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1573        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1574        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1575    });
1576    for follower_id in &request.follower_ids {
1577        let follower_id = ConnectionId(*follower_id);
1578        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1579            session
1580                .peer
1581                .forward_send(session.connection_id, follower_id, request.clone())?;
1582        }
1583    }
1584    Ok(())
1585}
1586
1587async fn get_users(
1588    request: proto::GetUsers,
1589    response: Response<proto::GetUsers>,
1590    session: Session,
1591) -> Result<()> {
1592    let user_ids = request
1593        .user_ids
1594        .into_iter()
1595        .map(UserId::from_proto)
1596        .collect();
1597    let users = session
1598        .db()
1599        .await
1600        .get_users_by_ids(user_ids)
1601        .await?
1602        .into_iter()
1603        .map(|user| proto::User {
1604            id: user.id.to_proto(),
1605            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1606            github_login: user.github_login,
1607        })
1608        .collect();
1609    response.send(proto::UsersResponse { users })?;
1610    Ok(())
1611}
1612
1613async fn fuzzy_search_users(
1614    request: proto::FuzzySearchUsers,
1615    response: Response<proto::FuzzySearchUsers>,
1616    session: Session,
1617) -> Result<()> {
1618    let query = request.query;
1619    let users = match query.len() {
1620        0 => vec![],
1621        1 | 2 => session
1622            .db()
1623            .await
1624            .get_user_by_github_account(&query, None)
1625            .await?
1626            .into_iter()
1627            .collect(),
1628        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1629    };
1630    let users = users
1631        .into_iter()
1632        .filter(|user| user.id != session.user_id)
1633        .map(|user| proto::User {
1634            id: user.id.to_proto(),
1635            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1636            github_login: user.github_login,
1637        })
1638        .collect();
1639    response.send(proto::UsersResponse { users })?;
1640    Ok(())
1641}
1642
1643async fn request_contact(
1644    request: proto::RequestContact,
1645    response: Response<proto::RequestContact>,
1646    session: Session,
1647) -> Result<()> {
1648    let requester_id = session.user_id;
1649    let responder_id = UserId::from_proto(request.responder_id);
1650    if requester_id == responder_id {
1651        return Err(anyhow!("cannot add yourself as a contact"))?;
1652    }
1653
1654    session
1655        .db()
1656        .await
1657        .send_contact_request(requester_id, responder_id)
1658        .await?;
1659
1660    // Update outgoing contact requests of requester
1661    let mut update = proto::UpdateContacts::default();
1662    update.outgoing_requests.push(responder_id.to_proto());
1663    for connection_id in session
1664        .connection_pool()
1665        .await
1666        .user_connection_ids(requester_id)
1667    {
1668        session.peer.send(connection_id, update.clone())?;
1669    }
1670
1671    // Update incoming contact requests of responder
1672    let mut update = proto::UpdateContacts::default();
1673    update
1674        .incoming_requests
1675        .push(proto::IncomingContactRequest {
1676            requester_id: requester_id.to_proto(),
1677            should_notify: true,
1678        });
1679    for connection_id in session
1680        .connection_pool()
1681        .await
1682        .user_connection_ids(responder_id)
1683    {
1684        session.peer.send(connection_id, update.clone())?;
1685    }
1686
1687    response.send(proto::Ack {})?;
1688    Ok(())
1689}
1690
1691async fn respond_to_contact_request(
1692    request: proto::RespondToContactRequest,
1693    response: Response<proto::RespondToContactRequest>,
1694    session: Session,
1695) -> Result<()> {
1696    let responder_id = session.user_id;
1697    let requester_id = UserId::from_proto(request.requester_id);
1698    let db = session.db().await;
1699    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1700        db.dismiss_contact_notification(responder_id, requester_id)
1701            .await?;
1702    } else {
1703        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1704
1705        db.respond_to_contact_request(responder_id, requester_id, accept)
1706            .await?;
1707        let requester_busy = db.is_user_busy(requester_id).await?;
1708        let responder_busy = db.is_user_busy(responder_id).await?;
1709
1710        let pool = session.connection_pool().await;
1711        // Update responder with new contact
1712        let mut update = proto::UpdateContacts::default();
1713        if accept {
1714            update
1715                .contacts
1716                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1717        }
1718        update
1719            .remove_incoming_requests
1720            .push(requester_id.to_proto());
1721        for connection_id in pool.user_connection_ids(responder_id) {
1722            session.peer.send(connection_id, update.clone())?;
1723        }
1724
1725        // Update requester with new contact
1726        let mut update = proto::UpdateContacts::default();
1727        if accept {
1728            update
1729                .contacts
1730                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1731        }
1732        update
1733            .remove_outgoing_requests
1734            .push(responder_id.to_proto());
1735        for connection_id in pool.user_connection_ids(requester_id) {
1736            session.peer.send(connection_id, update.clone())?;
1737        }
1738    }
1739
1740    response.send(proto::Ack {})?;
1741    Ok(())
1742}
1743
1744async fn remove_contact(
1745    request: proto::RemoveContact,
1746    response: Response<proto::RemoveContact>,
1747    session: Session,
1748) -> Result<()> {
1749    let requester_id = session.user_id;
1750    let responder_id = UserId::from_proto(request.user_id);
1751    let db = session.db().await;
1752    db.remove_contact(requester_id, responder_id).await?;
1753
1754    let pool = session.connection_pool().await;
1755    // Update outgoing contact requests of requester
1756    let mut update = proto::UpdateContacts::default();
1757    update
1758        .remove_outgoing_requests
1759        .push(responder_id.to_proto());
1760    for connection_id in pool.user_connection_ids(requester_id) {
1761        session.peer.send(connection_id, update.clone())?;
1762    }
1763
1764    // Update incoming contact requests of responder
1765    let mut update = proto::UpdateContacts::default();
1766    update
1767        .remove_incoming_requests
1768        .push(requester_id.to_proto());
1769    for connection_id in pool.user_connection_ids(responder_id) {
1770        session.peer.send(connection_id, update.clone())?;
1771    }
1772
1773    response.send(proto::Ack {})?;
1774    Ok(())
1775}
1776
1777async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1778    let project_id = ProjectId::from_proto(request.project_id);
1779    let project_connection_ids = session
1780        .db()
1781        .await
1782        .project_connection_ids(project_id, session.connection_id)
1783        .await?;
1784    broadcast(
1785        session.connection_id,
1786        project_connection_ids.iter().copied(),
1787        |connection_id| {
1788            session
1789                .peer
1790                .forward_send(session.connection_id, connection_id, request.clone())
1791        },
1792    );
1793    Ok(())
1794}
1795
1796async fn get_private_user_info(
1797    _request: proto::GetPrivateUserInfo,
1798    response: Response<proto::GetPrivateUserInfo>,
1799    session: Session,
1800) -> Result<()> {
1801    let metrics_id = session
1802        .db()
1803        .await
1804        .get_user_metrics_id(session.user_id)
1805        .await?;
1806    let user = session
1807        .db()
1808        .await
1809        .get_user_by_id(session.user_id)
1810        .await?
1811        .ok_or_else(|| anyhow!("user not found"))?;
1812    response.send(proto::GetPrivateUserInfoResponse {
1813        metrics_id,
1814        staff: user.admin,
1815    })?;
1816    Ok(())
1817}
1818
1819fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1820    match message {
1821        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1822        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1823        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1824        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1825        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1826            code: frame.code.into(),
1827            reason: frame.reason,
1828        })),
1829    }
1830}
1831
1832fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1833    match message {
1834        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1835        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1836        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1837        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1838        AxumMessage::Close(frame) => {
1839            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1840                code: frame.code.into(),
1841                reason: frame.reason,
1842            }))
1843        }
1844    }
1845}
1846
1847fn build_initial_contacts_update(
1848    contacts: Vec<db::Contact>,
1849    pool: &ConnectionPool,
1850) -> proto::UpdateContacts {
1851    let mut update = proto::UpdateContacts::default();
1852
1853    for contact in contacts {
1854        match contact {
1855            db::Contact::Accepted {
1856                user_id,
1857                should_notify,
1858                busy,
1859            } => {
1860                update
1861                    .contacts
1862                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1863            }
1864            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1865            db::Contact::Incoming {
1866                user_id,
1867                should_notify,
1868            } => update
1869                .incoming_requests
1870                .push(proto::IncomingContactRequest {
1871                    requester_id: user_id.to_proto(),
1872                    should_notify,
1873                }),
1874        }
1875    }
1876
1877    update
1878}
1879
1880fn contact_for_user(
1881    user_id: UserId,
1882    should_notify: bool,
1883    busy: bool,
1884    pool: &ConnectionPool,
1885) -> proto::Contact {
1886    proto::Contact {
1887        user_id: user_id.to_proto(),
1888        online: pool.is_user_online(user_id),
1889        busy,
1890        should_notify,
1891    }
1892}
1893
1894fn room_updated(room: &proto::Room, peer: &Peer) {
1895    for participant in &room.participants {
1896        peer.send(
1897            ConnectionId(participant.peer_id),
1898            proto::RoomUpdated {
1899                room: Some(room.clone()),
1900            },
1901        )
1902        .trace_err();
1903    }
1904}
1905
1906async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1907    let db = session.db().await;
1908    let contacts = db.get_contacts(user_id).await?;
1909    let busy = db.is_user_busy(user_id).await?;
1910
1911    let pool = session.connection_pool().await;
1912    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1913    for contact in contacts {
1914        if let db::Contact::Accepted {
1915            user_id: contact_user_id,
1916            ..
1917        } = contact
1918        {
1919            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1920                session
1921                    .peer
1922                    .send(
1923                        contact_conn_id,
1924                        proto::UpdateContacts {
1925                            contacts: vec![updated_contact.clone()],
1926                            remove_contacts: Default::default(),
1927                            incoming_requests: Default::default(),
1928                            remove_incoming_requests: Default::default(),
1929                            outgoing_requests: Default::default(),
1930                            remove_outgoing_requests: Default::default(),
1931                        },
1932                    )
1933                    .trace_err();
1934            }
1935        }
1936    }
1937    Ok(())
1938}
1939
1940async fn leave_room_for_session(session: &Session) -> Result<()> {
1941    let mut contacts_to_update = HashSet::default();
1942
1943    let room_id;
1944    let canceled_calls_to_user_ids;
1945    let live_kit_room;
1946    let delete_live_kit_room;
1947    {
1948        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1949        contacts_to_update.insert(session.user_id);
1950
1951        for project in left_room.left_projects.values() {
1952            project_left(project, session);
1953        }
1954
1955        room_updated(&left_room.room, &session.peer);
1956        room_id = RoomId::from_proto(left_room.room.id);
1957        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1958        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1959        delete_live_kit_room = left_room.room.participants.is_empty();
1960    }
1961
1962    {
1963        let pool = session.connection_pool().await;
1964        for canceled_user_id in canceled_calls_to_user_ids {
1965            for connection_id in pool.user_connection_ids(canceled_user_id) {
1966                session
1967                    .peer
1968                    .send(
1969                        connection_id,
1970                        proto::CallCanceled {
1971                            room_id: room_id.to_proto(),
1972                        },
1973                    )
1974                    .trace_err();
1975            }
1976            contacts_to_update.insert(canceled_user_id);
1977        }
1978    }
1979
1980    for contact_user_id in contacts_to_update {
1981        update_user_contacts(contact_user_id, &session).await?;
1982    }
1983
1984    if let Some(live_kit) = session.live_kit_client.as_ref() {
1985        live_kit
1986            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1987            .await
1988            .trace_err();
1989
1990        if delete_live_kit_room {
1991            live_kit.delete_room(live_kit_room).await.trace_err();
1992        }
1993    }
1994
1995    Ok(())
1996}
1997
1998fn project_left(project: &db::LeftProject, session: &Session) {
1999    for connection_id in &project.connection_ids {
2000        if project.host_user_id == session.user_id {
2001            session
2002                .peer
2003                .send(
2004                    *connection_id,
2005                    proto::UnshareProject {
2006                        project_id: project.id.to_proto(),
2007                    },
2008                )
2009                .trace_err();
2010        } else {
2011            session
2012                .peer
2013                .send(
2014                    *connection_id,
2015                    proto::RemoveProjectCollaborator {
2016                        project_id: project.id.to_proto(),
2017                        peer_id: session.connection_id.0,
2018                    },
2019                )
2020                .trace_err();
2021        }
2022    }
2023
2024    session
2025        .peer
2026        .send(
2027            session.connection_id,
2028            proto::UnshareProject {
2029                project_id: project.id.to_proto(),
2030            },
2031        )
2032        .trace_err();
2033}
2034
2035pub trait ResultExt {
2036    type Ok;
2037
2038    fn trace_err(self) -> Option<Self::Ok>;
2039}
2040
2041impl<T, E> ResultExt for Result<T, E>
2042where
2043    E: std::fmt::Debug,
2044{
2045    type Ok = T;
2046
2047    fn trace_err(self) -> Option<T> {
2048        match self {
2049            Ok(value) => Some(value),
2050            Err(error) => {
2051                tracing::error!("{:?}", error);
2052                None
2053            }
2054        }
2055    }
2056}