rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    peer: Arc<Peer>,
 142    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 143    app_state: Arc<AppState>,
 144    executor: Executor,
 145    handlers: HashMap<TypeId, MessageHandler>,
 146    teardown: watch::Sender<()>,
 147}
 148
 149pub(crate) struct ConnectionPoolGuard<'a> {
 150    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 151    _not_send: PhantomData<Rc<()>>,
 152}
 153
 154#[derive(Serialize)]
 155pub struct ServerSnapshot<'a> {
 156    peer: &'a Peer,
 157    #[serde(serialize_with = "serialize_deref")]
 158    connection_pool: ConnectionPoolGuard<'a>,
 159}
 160
 161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 162where
 163    S: Serializer,
 164    T: Deref<Target = U>,
 165    U: Serialize,
 166{
 167    Serialize::serialize(value.deref(), serializer)
 168}
 169
 170impl Server {
 171    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 172        let mut server = Self {
 173            peer: Peer::new(),
 174            app_state,
 175            executor,
 176            connection_pool: Default::default(),
 177            handlers: Default::default(),
 178            teardown: watch::channel(()).0,
 179        };
 180
 181        server
 182            .add_request_handler(ping)
 183            .add_request_handler(create_room)
 184            .add_request_handler(join_room)
 185            .add_message_handler(leave_room)
 186            .add_request_handler(call)
 187            .add_request_handler(cancel_call)
 188            .add_message_handler(decline_call)
 189            .add_request_handler(update_participant_location)
 190            .add_request_handler(share_project)
 191            .add_message_handler(unshare_project)
 192            .add_request_handler(join_project)
 193            .add_message_handler(leave_project)
 194            .add_request_handler(update_project)
 195            .add_request_handler(update_worktree)
 196            .add_message_handler(start_language_server)
 197            .add_message_handler(update_language_server)
 198            .add_message_handler(update_diagnostic_summary)
 199            .add_request_handler(forward_project_request::<proto::GetHover>)
 200            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 201            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 202            .add_request_handler(forward_project_request::<proto::GetReferences>)
 203            .add_request_handler(forward_project_request::<proto::SearchProject>)
 204            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 205            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 206            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 209            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 210            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 211            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 213            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 214            .add_request_handler(forward_project_request::<proto::PerformRename>)
 215            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 216            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 217            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 218            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 221            .add_message_handler(create_buffer_for_peer)
 222            .add_request_handler(update_buffer)
 223            .add_message_handler(update_buffer_file)
 224            .add_message_handler(buffer_reloaded)
 225            .add_message_handler(buffer_saved)
 226            .add_request_handler(save_buffer)
 227            .add_request_handler(get_users)
 228            .add_request_handler(fuzzy_search_users)
 229            .add_request_handler(request_contact)
 230            .add_request_handler(remove_contact)
 231            .add_request_handler(respond_to_contact_request)
 232            .add_request_handler(follow)
 233            .add_message_handler(unfollow)
 234            .add_message_handler(update_followers)
 235            .add_message_handler(update_diff_base)
 236            .add_request_handler(get_private_user_info);
 237
 238        Arc::new(server)
 239    }
 240
 241    pub fn start(&self) {
 242        let db = self.app_state.db.clone();
 243        let peer = self.peer.clone();
 244        let pool = self.connection_pool.clone();
 245        let live_kit_client = self.app_state.live_kit_client.clone();
 246        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 247        self.executor.spawn_detached(async move {
 248            db.delete_stale_projects().await.trace_err();
 249            timeout.await;
 250            if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
 251                for room_id in room_ids {
 252                    let mut contacts_to_update = HashSet::default();
 253                    let mut canceled_calls_to_user_ids = Vec::new();
 254                    let mut live_kit_room = String::new();
 255                    let mut delete_live_kit_room = false;
 256
 257                    if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 258                        room_updated(&refreshed_room.room, &peer);
 259                        contacts_to_update
 260                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 261                        contacts_to_update
 262                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 263                        canceled_calls_to_user_ids =
 264                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 265                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 266                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
 267                    }
 268
 269                    {
 270                        let pool = pool.lock();
 271                        for canceled_user_id in canceled_calls_to_user_ids {
 272                            for connection_id in pool.user_connection_ids(canceled_user_id) {
 273                                peer.send(
 274                                    connection_id,
 275                                    proto::CallCanceled {
 276                                        room_id: room_id.to_proto(),
 277                                    },
 278                                )
 279                                .trace_err();
 280                            }
 281                        }
 282                    }
 283
 284                    for user_id in contacts_to_update {
 285                        let busy = db.is_user_busy(user_id).await.trace_err();
 286                        let contacts = db.get_contacts(user_id).await.trace_err();
 287                        if let Some((busy, contacts)) = busy.zip(contacts) {
 288                            let pool = pool.lock();
 289                            let updated_contact = contact_for_user(user_id, false, busy, &pool);
 290                            for contact in contacts {
 291                                if let db::Contact::Accepted {
 292                                    user_id: contact_user_id,
 293                                    ..
 294                                } = contact
 295                                {
 296                                    for contact_conn_id in pool.user_connection_ids(contact_user_id)
 297                                    {
 298                                        peer.send(
 299                                            contact_conn_id,
 300                                            proto::UpdateContacts {
 301                                                contacts: vec![updated_contact.clone()],
 302                                                remove_contacts: Default::default(),
 303                                                incoming_requests: Default::default(),
 304                                                remove_incoming_requests: Default::default(),
 305                                                outgoing_requests: Default::default(),
 306                                                remove_outgoing_requests: Default::default(),
 307                                            },
 308                                        )
 309                                        .trace_err();
 310                                    }
 311                                }
 312                            }
 313                        }
 314                    }
 315
 316                    if let Some(live_kit) = live_kit_client.as_ref() {
 317                        if delete_live_kit_room {
 318                            live_kit.delete_room(live_kit_room).await.trace_err();
 319                        }
 320                    }
 321                }
 322            }
 323        });
 324    }
 325
 326    pub fn teardown(&self) {
 327        self.peer.reset();
 328        self.connection_pool.lock().reset();
 329        let _ = self.teardown.send(());
 330    }
 331
 332    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 333    where
 334        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 335        Fut: 'static + Send + Future<Output = Result<()>>,
 336        M: EnvelopedMessage,
 337    {
 338        let prev_handler = self.handlers.insert(
 339            TypeId::of::<M>(),
 340            Box::new(move |envelope, session| {
 341                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 342                let span = info_span!(
 343                    "handle message",
 344                    payload_type = envelope.payload_type_name()
 345                );
 346                span.in_scope(|| {
 347                    tracing::info!(
 348                        payload_type = envelope.payload_type_name(),
 349                        "message received"
 350                    );
 351                });
 352                let future = (handler)(*envelope, session);
 353                async move {
 354                    if let Err(error) = future.await {
 355                        tracing::error!(%error, "error handling message");
 356                    }
 357                }
 358                .instrument(span)
 359                .boxed()
 360            }),
 361        );
 362        if prev_handler.is_some() {
 363            panic!("registered a handler for the same message twice");
 364        }
 365        self
 366    }
 367
 368    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 369    where
 370        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 371        Fut: 'static + Send + Future<Output = Result<()>>,
 372        M: EnvelopedMessage,
 373    {
 374        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 375        self
 376    }
 377
 378    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 379    where
 380        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 381        Fut: Send + Future<Output = Result<()>>,
 382        M: RequestMessage,
 383    {
 384        let handler = Arc::new(handler);
 385        self.add_handler(move |envelope, session| {
 386            let receipt = envelope.receipt();
 387            let handler = handler.clone();
 388            async move {
 389                let peer = session.peer.clone();
 390                let responded = Arc::new(AtomicBool::default());
 391                let response = Response {
 392                    peer: peer.clone(),
 393                    responded: responded.clone(),
 394                    receipt,
 395                };
 396                match (handler)(envelope.payload, response, session).await {
 397                    Ok(()) => {
 398                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 399                            Ok(())
 400                        } else {
 401                            Err(anyhow!("handler did not send a response"))?
 402                        }
 403                    }
 404                    Err(error) => {
 405                        peer.respond_with_error(
 406                            receipt,
 407                            proto::Error {
 408                                message: error.to_string(),
 409                            },
 410                        )?;
 411                        Err(error)
 412                    }
 413                }
 414            }
 415        })
 416    }
 417
 418    pub fn handle_connection(
 419        self: &Arc<Self>,
 420        connection: Connection,
 421        address: String,
 422        user: User,
 423        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 424        executor: Executor,
 425    ) -> impl Future<Output = Result<()>> {
 426        let this = self.clone();
 427        let user_id = user.id;
 428        let login = user.github_login;
 429        let span = info_span!("handle connection", %user_id, %login, %address);
 430        let mut teardown = self.teardown.subscribe();
 431        async move {
 432            let (connection_id, handle_io, mut incoming_rx) = this
 433                .peer
 434                .add_connection(connection, {
 435                    let executor = executor.clone();
 436                    move |duration| executor.sleep(duration)
 437                });
 438
 439            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 440            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 441            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 442
 443            if let Some(send_connection_id) = send_connection_id.take() {
 444                let _ = send_connection_id.send(connection_id);
 445            }
 446
 447            if !user.connected_once {
 448                this.peer.send(connection_id, proto::ShowContacts {})?;
 449                this.app_state.db.set_user_connected_once(user_id, true).await?;
 450            }
 451
 452            let (contacts, invite_code) = future::try_join(
 453                this.app_state.db.get_contacts(user_id),
 454                this.app_state.db.get_invite_code_for_user(user_id)
 455            ).await?;
 456
 457            {
 458                let mut pool = this.connection_pool.lock();
 459                pool.add_connection(connection_id, user_id, user.admin);
 460                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 461
 462                if let Some((code, count)) = invite_code {
 463                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 464                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 465                        count: count as u32,
 466                    })?;
 467                }
 468            }
 469
 470            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 471                this.peer.send(connection_id, incoming_call)?;
 472            }
 473
 474            let session = Session {
 475                user_id,
 476                connection_id,
 477                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 478                peer: this.peer.clone(),
 479                connection_pool: this.connection_pool.clone(),
 480                live_kit_client: this.app_state.live_kit_client.clone()
 481            };
 482            update_user_contacts(user_id, &session).await?;
 483
 484            let handle_io = handle_io.fuse();
 485            futures::pin_mut!(handle_io);
 486
 487            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 488            // This prevents deadlocks when e.g., client A performs a request to client B and
 489            // client B performs a request to client A. If both clients stop processing further
 490            // messages until their respective request completes, they won't have a chance to
 491            // respond to the other client's request and cause a deadlock.
 492            //
 493            // This arrangement ensures we will attempt to process earlier messages first, but fall
 494            // back to processing messages arrived later in the spirit of making progress.
 495            let mut foreground_message_handlers = FuturesUnordered::new();
 496            loop {
 497                let next_message = incoming_rx.next().fuse();
 498                futures::pin_mut!(next_message);
 499                futures::select_biased! {
 500                    _ = teardown.changed().fuse() => return Ok(()),
 501                    result = handle_io => {
 502                        if let Err(error) = result {
 503                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 504                        }
 505                        break;
 506                    }
 507                    _ = foreground_message_handlers.next() => {}
 508                    message = next_message => {
 509                        if let Some(message) = message {
 510                            let type_name = message.payload_type_name();
 511                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 512                            let span_enter = span.enter();
 513                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 514                                let is_background = message.is_background();
 515                                let handle_message = (handler)(message, session.clone());
 516                                drop(span_enter);
 517
 518                                let handle_message = handle_message.instrument(span);
 519                                if is_background {
 520                                    executor.spawn_detached(handle_message);
 521                                } else {
 522                                    foreground_message_handlers.push(handle_message);
 523                                }
 524                            } else {
 525                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 526                            }
 527                        } else {
 528                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 529                            break;
 530                        }
 531                    }
 532                }
 533            }
 534
 535            drop(foreground_message_handlers);
 536            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 537            if let Err(error) = sign_out(session, teardown, executor).await {
 538                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 539            }
 540
 541            Ok(())
 542        }.instrument(span)
 543    }
 544
 545    pub async fn invite_code_redeemed(
 546        self: &Arc<Self>,
 547        inviter_id: UserId,
 548        invitee_id: UserId,
 549    ) -> Result<()> {
 550        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 551            if let Some(code) = &user.invite_code {
 552                let pool = self.connection_pool.lock();
 553                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 554                for connection_id in pool.user_connection_ids(inviter_id) {
 555                    self.peer.send(
 556                        connection_id,
 557                        proto::UpdateContacts {
 558                            contacts: vec![invitee_contact.clone()],
 559                            ..Default::default()
 560                        },
 561                    )?;
 562                    self.peer.send(
 563                        connection_id,
 564                        proto::UpdateInviteInfo {
 565                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 566                            count: user.invite_count as u32,
 567                        },
 568                    )?;
 569                }
 570            }
 571        }
 572        Ok(())
 573    }
 574
 575    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 576        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 577            if let Some(invite_code) = &user.invite_code {
 578                let pool = self.connection_pool.lock();
 579                for connection_id in pool.user_connection_ids(user_id) {
 580                    self.peer.send(
 581                        connection_id,
 582                        proto::UpdateInviteInfo {
 583                            url: format!(
 584                                "{}{}",
 585                                self.app_state.config.invite_link_prefix, invite_code
 586                            ),
 587                            count: user.invite_count as u32,
 588                        },
 589                    )?;
 590                }
 591            }
 592        }
 593        Ok(())
 594    }
 595
 596    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 597        ServerSnapshot {
 598            connection_pool: ConnectionPoolGuard {
 599                guard: self.connection_pool.lock(),
 600                _not_send: PhantomData,
 601            },
 602            peer: &self.peer,
 603        }
 604    }
 605}
 606
 607impl<'a> Deref for ConnectionPoolGuard<'a> {
 608    type Target = ConnectionPool;
 609
 610    fn deref(&self) -> &Self::Target {
 611        &*self.guard
 612    }
 613}
 614
 615impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 616    fn deref_mut(&mut self) -> &mut Self::Target {
 617        &mut *self.guard
 618    }
 619}
 620
 621impl<'a> Drop for ConnectionPoolGuard<'a> {
 622    fn drop(&mut self) {
 623        #[cfg(test)]
 624        self.check_invariants();
 625    }
 626}
 627
 628fn broadcast<F>(
 629    sender_id: ConnectionId,
 630    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 631    mut f: F,
 632) where
 633    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 634{
 635    for receiver_id in receiver_ids {
 636        if receiver_id != sender_id {
 637            f(receiver_id).trace_err();
 638        }
 639    }
 640}
 641
 642lazy_static! {
 643    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 644}
 645
 646pub struct ProtocolVersion(u32);
 647
 648impl Header for ProtocolVersion {
 649    fn name() -> &'static HeaderName {
 650        &ZED_PROTOCOL_VERSION
 651    }
 652
 653    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 654    where
 655        Self: Sized,
 656        I: Iterator<Item = &'i axum::http::HeaderValue>,
 657    {
 658        let version = values
 659            .next()
 660            .ok_or_else(axum::headers::Error::invalid)?
 661            .to_str()
 662            .map_err(|_| axum::headers::Error::invalid())?
 663            .parse()
 664            .map_err(|_| axum::headers::Error::invalid())?;
 665        Ok(Self(version))
 666    }
 667
 668    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 669        values.extend([self.0.to_string().parse().unwrap()]);
 670    }
 671}
 672
 673pub fn routes(server: Arc<Server>) -> Router<Body> {
 674    Router::new()
 675        .route("/rpc", get(handle_websocket_request))
 676        .layer(
 677            ServiceBuilder::new()
 678                .layer(Extension(server.app_state.clone()))
 679                .layer(middleware::from_fn(auth::validate_header)),
 680        )
 681        .route("/metrics", get(handle_metrics))
 682        .layer(Extension(server))
 683}
 684
 685pub async fn handle_websocket_request(
 686    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 687    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 688    Extension(server): Extension<Arc<Server>>,
 689    Extension(user): Extension<User>,
 690    ws: WebSocketUpgrade,
 691) -> axum::response::Response {
 692    if protocol_version != rpc::PROTOCOL_VERSION {
 693        return (
 694            StatusCode::UPGRADE_REQUIRED,
 695            "client must be upgraded".to_string(),
 696        )
 697            .into_response();
 698    }
 699    let socket_address = socket_address.to_string();
 700    ws.on_upgrade(move |socket| {
 701        use util::ResultExt;
 702        let socket = socket
 703            .map_ok(to_tungstenite_message)
 704            .err_into()
 705            .with(|message| async move { Ok(to_axum_message(message)) });
 706        let connection = Connection::new(Box::pin(socket));
 707        async move {
 708            server
 709                .handle_connection(connection, socket_address, user, None, Executor::Production)
 710                .await
 711                .log_err();
 712        }
 713    })
 714}
 715
 716pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 717    let connections = server
 718        .connection_pool
 719        .lock()
 720        .connections()
 721        .filter(|connection| !connection.admin)
 722        .count();
 723
 724    METRIC_CONNECTIONS.set(connections as _);
 725
 726    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 727    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 728
 729    let encoder = prometheus::TextEncoder::new();
 730    let metric_families = prometheus::gather();
 731    let encoded_metrics = encoder
 732        .encode_to_string(&metric_families)
 733        .map_err(|err| anyhow!("{}", err))?;
 734    Ok(encoded_metrics)
 735}
 736
/// Cleans up a connection's server-side state after its I/O loop has ended.
///
/// Immediately, it does three things:
/// - disconnects the peer;
/// - removes the connection from the pool (returning early if that fails);
/// - calls `connection_lost` and runs `project_left` for each project it reports.
///
/// It then waits `RECONNECT_TIMEOUT`. If the server is not torn down first:
/// - `leave_room_for_session` runs;
/// - if the user has no other online connection, `decline_call` is invoked for them;
/// - the user's contacts are updated.
///
/// A teardown signal ends the wait and skips that deferred cleanup.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // NOTE(review): `connection_lost` appears to return a guard that derefs to the list of left
    // projects; the db lock from `session.db()` is held for the whole `if let` body.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Biased select: the reconnect timer arm is polled before the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline a pending call if no other connection for this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 782
 783async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 784    response.send(proto::Ack {})?;
 785    Ok(())
 786}
 787
 788async fn create_room(
 789    _request: proto::CreateRoom,
 790    response: Response<proto::CreateRoom>,
 791    session: Session,
 792) -> Result<()> {
 793    let live_kit_room = nanoid::nanoid!(30);
 794    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 795        if let Some(_) = live_kit
 796            .create_room(live_kit_room.clone())
 797            .await
 798            .trace_err()
 799        {
 800            if let Some(token) = live_kit
 801                .room_token(&live_kit_room, &session.connection_id.to_string())
 802                .trace_err()
 803            {
 804                Some(proto::LiveKitConnectionInfo {
 805                    server_url: live_kit.url().into(),
 806                    token,
 807                })
 808            } else {
 809                None
 810            }
 811        } else {
 812            None
 813        }
 814    } else {
 815        None
 816    };
 817
 818    {
 819        let room = session
 820            .db()
 821            .await
 822            .create_room(session.user_id, session.connection_id, &live_kit_room)
 823            .await?;
 824
 825        response.send(proto::CreateRoomResponse {
 826            room: Some(room.clone()),
 827            live_kit_connection_info,
 828        })?;
 829    }
 830
 831    update_user_contacts(session.user_id, &session).await?;
 832    Ok(())
 833}
 834
 835async fn join_room(
 836    request: proto::JoinRoom,
 837    response: Response<proto::JoinRoom>,
 838    session: Session,
 839) -> Result<()> {
 840    let room_id = RoomId::from_proto(request.id);
 841    let room = {
 842        let room = session
 843            .db()
 844            .await
 845            .join_room(room_id, session.user_id, session.connection_id)
 846            .await?;
 847        room_updated(&room, &session.peer);
 848        room.clone()
 849    };
 850
 851    for connection_id in session
 852        .connection_pool()
 853        .await
 854        .user_connection_ids(session.user_id)
 855    {
 856        session
 857            .peer
 858            .send(
 859                connection_id,
 860                proto::CallCanceled {
 861                    room_id: room_id.to_proto(),
 862                },
 863            )
 864            .trace_err();
 865    }
 866
 867    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 868        if let Some(token) = live_kit
 869            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 870            .trace_err()
 871        {
 872            Some(proto::LiveKitConnectionInfo {
 873                server_url: live_kit.url().into(),
 874                token,
 875            })
 876        } else {
 877            None
 878        }
 879    } else {
 880        None
 881    };
 882
 883    response.send(proto::JoinRoomResponse {
 884        room: Some(room),
 885        live_kit_connection_info,
 886    })?;
 887
 888    update_user_contacts(session.user_id, &session).await?;
 889    Ok(())
 890}
 891
 892async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 893    leave_room_for_session(&session).await
 894}
 895
 896async fn call(
 897    request: proto::Call,
 898    response: Response<proto::Call>,
 899    session: Session,
 900) -> Result<()> {
 901    let room_id = RoomId::from_proto(request.room_id);
 902    let calling_user_id = session.user_id;
 903    let calling_connection_id = session.connection_id;
 904    let called_user_id = UserId::from_proto(request.called_user_id);
 905    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 906    if !session
 907        .db()
 908        .await
 909        .has_contact(calling_user_id, called_user_id)
 910        .await?
 911    {
 912        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 913    }
 914
 915    let incoming_call = {
 916        let (room, incoming_call) = &mut *session
 917            .db()
 918            .await
 919            .call(
 920                room_id,
 921                calling_user_id,
 922                calling_connection_id,
 923                called_user_id,
 924                initial_project_id,
 925            )
 926            .await?;
 927        room_updated(&room, &session.peer);
 928        mem::take(incoming_call)
 929    };
 930    update_user_contacts(called_user_id, &session).await?;
 931
 932    let mut calls = session
 933        .connection_pool()
 934        .await
 935        .user_connection_ids(called_user_id)
 936        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 937        .collect::<FuturesUnordered<_>>();
 938
 939    while let Some(call_response) = calls.next().await {
 940        match call_response.as_ref() {
 941            Ok(_) => {
 942                response.send(proto::Ack {})?;
 943                return Ok(());
 944            }
 945            Err(_) => {
 946                call_response.trace_err();
 947            }
 948        }
 949    }
 950
 951    {
 952        let room = session
 953            .db()
 954            .await
 955            .call_failed(room_id, called_user_id)
 956            .await?;
 957        room_updated(&room, &session.peer);
 958    }
 959    update_user_contacts(called_user_id, &session).await?;
 960
 961    Err(anyhow!("failed to ring user"))?
 962}
 963
 964async fn cancel_call(
 965    request: proto::CancelCall,
 966    response: Response<proto::CancelCall>,
 967    session: Session,
 968) -> Result<()> {
 969    let called_user_id = UserId::from_proto(request.called_user_id);
 970    let room_id = RoomId::from_proto(request.room_id);
 971    {
 972        let room = session
 973            .db()
 974            .await
 975            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 976            .await?;
 977        room_updated(&room, &session.peer);
 978    }
 979
 980    for connection_id in session
 981        .connection_pool()
 982        .await
 983        .user_connection_ids(called_user_id)
 984    {
 985        session
 986            .peer
 987            .send(
 988                connection_id,
 989                proto::CallCanceled {
 990                    room_id: room_id.to_proto(),
 991                },
 992            )
 993            .trace_err();
 994    }
 995    response.send(proto::Ack {})?;
 996
 997    update_user_contacts(called_user_id, &session).await?;
 998    Ok(())
 999}
1000
1001async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1002    let room_id = RoomId::from_proto(message.room_id);
1003    {
1004        let room = session
1005            .db()
1006            .await
1007            .decline_call(Some(room_id), session.user_id)
1008            .await?;
1009        room_updated(&room, &session.peer);
1010    }
1011
1012    for connection_id in session
1013        .connection_pool()
1014        .await
1015        .user_connection_ids(session.user_id)
1016    {
1017        session
1018            .peer
1019            .send(
1020                connection_id,
1021                proto::CallCanceled {
1022                    room_id: room_id.to_proto(),
1023                },
1024            )
1025            .trace_err();
1026    }
1027    update_user_contacts(session.user_id, &session).await?;
1028    Ok(())
1029}
1030
1031async fn update_participant_location(
1032    request: proto::UpdateParticipantLocation,
1033    response: Response<proto::UpdateParticipantLocation>,
1034    session: Session,
1035) -> Result<()> {
1036    let room_id = RoomId::from_proto(request.room_id);
1037    let location = request
1038        .location
1039        .ok_or_else(|| anyhow!("invalid location"))?;
1040    let room = session
1041        .db()
1042        .await
1043        .update_room_participant_location(room_id, session.connection_id, location)
1044        .await?;
1045    room_updated(&room, &session.peer);
1046    response.send(proto::Ack {})?;
1047    Ok(())
1048}
1049
1050async fn share_project(
1051    request: proto::ShareProject,
1052    response: Response<proto::ShareProject>,
1053    session: Session,
1054) -> Result<()> {
1055    let (project_id, room) = &*session
1056        .db()
1057        .await
1058        .share_project(
1059            RoomId::from_proto(request.room_id),
1060            session.connection_id,
1061            &request.worktrees,
1062        )
1063        .await?;
1064    response.send(proto::ShareProjectResponse {
1065        project_id: project_id.to_proto(),
1066    })?;
1067    room_updated(&room, &session.peer);
1068
1069    Ok(())
1070}
1071
1072async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1073    let project_id = ProjectId::from_proto(message.project_id);
1074
1075    let (room, guest_connection_ids) = &*session
1076        .db()
1077        .await
1078        .unshare_project(project_id, session.connection_id)
1079        .await?;
1080
1081    broadcast(
1082        session.connection_id,
1083        guest_connection_ids.iter().copied(),
1084        |conn_id| session.peer.send(conn_id, message.clone()),
1085    );
1086    room_updated(&room, &session.peer);
1087
1088    Ok(())
1089}
1090
1091async fn join_project(
1092    request: proto::JoinProject,
1093    response: Response<proto::JoinProject>,
1094    session: Session,
1095) -> Result<()> {
1096    let project_id = ProjectId::from_proto(request.project_id);
1097    let guest_user_id = session.user_id;
1098
1099    tracing::info!(%project_id, "join project");
1100
1101    let (project, replica_id) = &mut *session
1102        .db()
1103        .await
1104        .join_project(project_id, session.connection_id)
1105        .await?;
1106
1107    let collaborators = project
1108        .collaborators
1109        .iter()
1110        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1111        .map(|collaborator| proto::Collaborator {
1112            peer_id: collaborator.connection_id as u32,
1113            replica_id: collaborator.replica_id.0 as u32,
1114            user_id: collaborator.user_id.to_proto(),
1115        })
1116        .collect::<Vec<_>>();
1117    let worktrees = project
1118        .worktrees
1119        .iter()
1120        .map(|(id, worktree)| proto::WorktreeMetadata {
1121            id: *id,
1122            root_name: worktree.root_name.clone(),
1123            visible: worktree.visible,
1124            abs_path: worktree.abs_path.clone(),
1125        })
1126        .collect::<Vec<_>>();
1127
1128    for collaborator in &collaborators {
1129        session
1130            .peer
1131            .send(
1132                ConnectionId(collaborator.peer_id),
1133                proto::AddProjectCollaborator {
1134                    project_id: project_id.to_proto(),
1135                    collaborator: Some(proto::Collaborator {
1136                        peer_id: session.connection_id.0,
1137                        replica_id: replica_id.0 as u32,
1138                        user_id: guest_user_id.to_proto(),
1139                    }),
1140                },
1141            )
1142            .trace_err();
1143    }
1144
1145    // First, we send the metadata associated with each worktree.
1146    response.send(proto::JoinProjectResponse {
1147        worktrees: worktrees.clone(),
1148        replica_id: replica_id.0 as u32,
1149        collaborators: collaborators.clone(),
1150        language_servers: project.language_servers.clone(),
1151    })?;
1152
1153    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1154        #[cfg(any(test, feature = "test-support"))]
1155        const MAX_CHUNK_SIZE: usize = 2;
1156        #[cfg(not(any(test, feature = "test-support")))]
1157        const MAX_CHUNK_SIZE: usize = 256;
1158
1159        // Stream this worktree's entries.
1160        let message = proto::UpdateWorktree {
1161            project_id: project_id.to_proto(),
1162            worktree_id,
1163            abs_path: worktree.abs_path.clone(),
1164            root_name: worktree.root_name,
1165            updated_entries: worktree.entries,
1166            removed_entries: Default::default(),
1167            scan_id: worktree.scan_id,
1168            is_last_update: worktree.is_complete,
1169        };
1170        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1171            session.peer.send(session.connection_id, update.clone())?;
1172        }
1173
1174        // Stream this worktree's diagnostics.
1175        for summary in worktree.diagnostic_summaries {
1176            session.peer.send(
1177                session.connection_id,
1178                proto::UpdateDiagnosticSummary {
1179                    project_id: project_id.to_proto(),
1180                    worktree_id: worktree.id,
1181                    summary: Some(summary),
1182                },
1183            )?;
1184        }
1185    }
1186
1187    for language_server in &project.language_servers {
1188        session.peer.send(
1189            session.connection_id,
1190            proto::UpdateLanguageServer {
1191                project_id: project_id.to_proto(),
1192                language_server_id: language_server.id,
1193                variant: Some(
1194                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1195                        proto::LspDiskBasedDiagnosticsUpdated {},
1196                    ),
1197                ),
1198            },
1199        )?;
1200    }
1201
1202    Ok(())
1203}
1204
1205async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1206    let sender_id = session.connection_id;
1207    let project_id = ProjectId::from_proto(request.project_id);
1208
1209    let project = session
1210        .db()
1211        .await
1212        .leave_project(project_id, sender_id)
1213        .await?;
1214    tracing::info!(
1215        %project_id,
1216        host_user_id = %project.host_user_id,
1217        host_connection_id = %project.host_connection_id,
1218        "leave project"
1219    );
1220    project_left(&project, &session);
1221
1222    Ok(())
1223}
1224
1225async fn update_project(
1226    request: proto::UpdateProject,
1227    response: Response<proto::UpdateProject>,
1228    session: Session,
1229) -> Result<()> {
1230    let project_id = ProjectId::from_proto(request.project_id);
1231    let (room, guest_connection_ids) = &*session
1232        .db()
1233        .await
1234        .update_project(project_id, session.connection_id, &request.worktrees)
1235        .await?;
1236    broadcast(
1237        session.connection_id,
1238        guest_connection_ids.iter().copied(),
1239        |connection_id| {
1240            session
1241                .peer
1242                .forward_send(session.connection_id, connection_id, request.clone())
1243        },
1244    );
1245    room_updated(&room, &session.peer);
1246    response.send(proto::Ack {})?;
1247
1248    Ok(())
1249}
1250
1251async fn update_worktree(
1252    request: proto::UpdateWorktree,
1253    response: Response<proto::UpdateWorktree>,
1254    session: Session,
1255) -> Result<()> {
1256    let guest_connection_ids = session
1257        .db()
1258        .await
1259        .update_worktree(&request, session.connection_id)
1260        .await?;
1261
1262    broadcast(
1263        session.connection_id,
1264        guest_connection_ids.iter().copied(),
1265        |connection_id| {
1266            session
1267                .peer
1268                .forward_send(session.connection_id, connection_id, request.clone())
1269        },
1270    );
1271    response.send(proto::Ack {})?;
1272    Ok(())
1273}
1274
1275async fn update_diagnostic_summary(
1276    message: proto::UpdateDiagnosticSummary,
1277    session: Session,
1278) -> Result<()> {
1279    let guest_connection_ids = session
1280        .db()
1281        .await
1282        .update_diagnostic_summary(&message, session.connection_id)
1283        .await?;
1284
1285    broadcast(
1286        session.connection_id,
1287        guest_connection_ids.iter().copied(),
1288        |connection_id| {
1289            session
1290                .peer
1291                .forward_send(session.connection_id, connection_id, message.clone())
1292        },
1293    );
1294
1295    Ok(())
1296}
1297
1298async fn start_language_server(
1299    request: proto::StartLanguageServer,
1300    session: Session,
1301) -> Result<()> {
1302    let guest_connection_ids = session
1303        .db()
1304        .await
1305        .start_language_server(&request, session.connection_id)
1306        .await?;
1307
1308    broadcast(
1309        session.connection_id,
1310        guest_connection_ids.iter().copied(),
1311        |connection_id| {
1312            session
1313                .peer
1314                .forward_send(session.connection_id, connection_id, request.clone())
1315        },
1316    );
1317    Ok(())
1318}
1319
1320async fn update_language_server(
1321    request: proto::UpdateLanguageServer,
1322    session: Session,
1323) -> Result<()> {
1324    let project_id = ProjectId::from_proto(request.project_id);
1325    let project_connection_ids = session
1326        .db()
1327        .await
1328        .project_connection_ids(project_id, session.connection_id)
1329        .await?;
1330    broadcast(
1331        session.connection_id,
1332        project_connection_ids.iter().copied(),
1333        |connection_id| {
1334            session
1335                .peer
1336                .forward_send(session.connection_id, connection_id, request.clone())
1337        },
1338    );
1339    Ok(())
1340}
1341
1342async fn forward_project_request<T>(
1343    request: T,
1344    response: Response<T>,
1345    session: Session,
1346) -> Result<()>
1347where
1348    T: EntityMessage + RequestMessage,
1349{
1350    let project_id = ProjectId::from_proto(request.remote_entity_id());
1351    let host_connection_id = {
1352        let collaborators = session
1353            .db()
1354            .await
1355            .project_collaborators(project_id, session.connection_id)
1356            .await?;
1357        ConnectionId(
1358            collaborators
1359                .iter()
1360                .find(|collaborator| collaborator.is_host)
1361                .ok_or_else(|| anyhow!("host not found"))?
1362                .connection_id as u32,
1363        )
1364    };
1365
1366    let payload = session
1367        .peer
1368        .forward_request(session.connection_id, host_connection_id, request)
1369        .await?;
1370
1371    response.send(payload)?;
1372    Ok(())
1373}
1374
1375async fn save_buffer(
1376    request: proto::SaveBuffer,
1377    response: Response<proto::SaveBuffer>,
1378    session: Session,
1379) -> Result<()> {
1380    let project_id = ProjectId::from_proto(request.project_id);
1381    let host_connection_id = {
1382        let collaborators = session
1383            .db()
1384            .await
1385            .project_collaborators(project_id, session.connection_id)
1386            .await?;
1387        let host = collaborators
1388            .iter()
1389            .find(|collaborator| collaborator.is_host)
1390            .ok_or_else(|| anyhow!("host not found"))?;
1391        ConnectionId(host.connection_id as u32)
1392    };
1393    let response_payload = session
1394        .peer
1395        .forward_request(session.connection_id, host_connection_id, request.clone())
1396        .await?;
1397
1398    let mut collaborators = session
1399        .db()
1400        .await
1401        .project_collaborators(project_id, session.connection_id)
1402        .await?;
1403    collaborators
1404        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1405    let project_connection_ids = collaborators
1406        .iter()
1407        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1408    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1409        session
1410            .peer
1411            .forward_send(host_connection_id, conn_id, response_payload.clone())
1412    });
1413    response.send(response_payload)?;
1414    Ok(())
1415}
1416
1417async fn create_buffer_for_peer(
1418    request: proto::CreateBufferForPeer,
1419    session: Session,
1420) -> Result<()> {
1421    session.peer.forward_send(
1422        session.connection_id,
1423        ConnectionId(request.peer_id),
1424        request,
1425    )?;
1426    Ok(())
1427}
1428
1429async fn update_buffer(
1430    request: proto::UpdateBuffer,
1431    response: Response<proto::UpdateBuffer>,
1432    session: Session,
1433) -> Result<()> {
1434    let project_id = ProjectId::from_proto(request.project_id);
1435    let project_connection_ids = session
1436        .db()
1437        .await
1438        .project_connection_ids(project_id, session.connection_id)
1439        .await?;
1440
1441    broadcast(
1442        session.connection_id,
1443        project_connection_ids.iter().copied(),
1444        |connection_id| {
1445            session
1446                .peer
1447                .forward_send(session.connection_id, connection_id, request.clone())
1448        },
1449    );
1450    response.send(proto::Ack {})?;
1451    Ok(())
1452}
1453
1454async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1455    let project_id = ProjectId::from_proto(request.project_id);
1456    let project_connection_ids = session
1457        .db()
1458        .await
1459        .project_connection_ids(project_id, session.connection_id)
1460        .await?;
1461
1462    broadcast(
1463        session.connection_id,
1464        project_connection_ids.iter().copied(),
1465        |connection_id| {
1466            session
1467                .peer
1468                .forward_send(session.connection_id, connection_id, request.clone())
1469        },
1470    );
1471    Ok(())
1472}
1473
1474async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1475    let project_id = ProjectId::from_proto(request.project_id);
1476    let project_connection_ids = session
1477        .db()
1478        .await
1479        .project_connection_ids(project_id, session.connection_id)
1480        .await?;
1481    broadcast(
1482        session.connection_id,
1483        project_connection_ids.iter().copied(),
1484        |connection_id| {
1485            session
1486                .peer
1487                .forward_send(session.connection_id, connection_id, request.clone())
1488        },
1489    );
1490    Ok(())
1491}
1492
1493async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1494    let project_id = ProjectId::from_proto(request.project_id);
1495    let project_connection_ids = session
1496        .db()
1497        .await
1498        .project_connection_ids(project_id, session.connection_id)
1499        .await?;
1500    broadcast(
1501        session.connection_id,
1502        project_connection_ids.iter().copied(),
1503        |connection_id| {
1504            session
1505                .peer
1506                .forward_send(session.connection_id, connection_id, request.clone())
1507        },
1508    );
1509    Ok(())
1510}
1511
1512async fn follow(
1513    request: proto::Follow,
1514    response: Response<proto::Follow>,
1515    session: Session,
1516) -> Result<()> {
1517    let project_id = ProjectId::from_proto(request.project_id);
1518    let leader_id = ConnectionId(request.leader_id);
1519    let follower_id = session.connection_id;
1520    {
1521        let project_connection_ids = session
1522            .db()
1523            .await
1524            .project_connection_ids(project_id, session.connection_id)
1525            .await?;
1526
1527        if !project_connection_ids.contains(&leader_id) {
1528            Err(anyhow!("no such peer"))?;
1529        }
1530    }
1531
1532    let mut response_payload = session
1533        .peer
1534        .forward_request(session.connection_id, leader_id, request)
1535        .await?;
1536    response_payload
1537        .views
1538        .retain(|view| view.leader_id != Some(follower_id.0));
1539    response.send(response_payload)?;
1540    Ok(())
1541}
1542
1543async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1544    let project_id = ProjectId::from_proto(request.project_id);
1545    let leader_id = ConnectionId(request.leader_id);
1546    let project_connection_ids = session
1547        .db()
1548        .await
1549        .project_connection_ids(project_id, session.connection_id)
1550        .await?;
1551    if !project_connection_ids.contains(&leader_id) {
1552        Err(anyhow!("no such peer"))?;
1553    }
1554    session
1555        .peer
1556        .forward_send(session.connection_id, leader_id, request)?;
1557    Ok(())
1558}
1559
1560async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1561    let project_id = ProjectId::from_proto(request.project_id);
1562    let project_connection_ids = session
1563        .db
1564        .lock()
1565        .await
1566        .project_connection_ids(project_id, session.connection_id)
1567        .await?;
1568
1569    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1570        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1571        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1572        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1573    });
1574    for follower_id in &request.follower_ids {
1575        let follower_id = ConnectionId(*follower_id);
1576        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1577            session
1578                .peer
1579                .forward_send(session.connection_id, follower_id, request.clone())?;
1580        }
1581    }
1582    Ok(())
1583}
1584
1585async fn get_users(
1586    request: proto::GetUsers,
1587    response: Response<proto::GetUsers>,
1588    session: Session,
1589) -> Result<()> {
1590    let user_ids = request
1591        .user_ids
1592        .into_iter()
1593        .map(UserId::from_proto)
1594        .collect();
1595    let users = session
1596        .db()
1597        .await
1598        .get_users_by_ids(user_ids)
1599        .await?
1600        .into_iter()
1601        .map(|user| proto::User {
1602            id: user.id.to_proto(),
1603            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1604            github_login: user.github_login,
1605        })
1606        .collect();
1607    response.send(proto::UsersResponse { users })?;
1608    Ok(())
1609}
1610
1611async fn fuzzy_search_users(
1612    request: proto::FuzzySearchUsers,
1613    response: Response<proto::FuzzySearchUsers>,
1614    session: Session,
1615) -> Result<()> {
1616    let query = request.query;
1617    let users = match query.len() {
1618        0 => vec![],
1619        1 | 2 => session
1620            .db()
1621            .await
1622            .get_user_by_github_account(&query, None)
1623            .await?
1624            .into_iter()
1625            .collect(),
1626        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1627    };
1628    let users = users
1629        .into_iter()
1630        .filter(|user| user.id != session.user_id)
1631        .map(|user| proto::User {
1632            id: user.id.to_proto(),
1633            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1634            github_login: user.github_login,
1635        })
1636        .collect();
1637    response.send(proto::UsersResponse { users })?;
1638    Ok(())
1639}
1640
1641async fn request_contact(
1642    request: proto::RequestContact,
1643    response: Response<proto::RequestContact>,
1644    session: Session,
1645) -> Result<()> {
1646    let requester_id = session.user_id;
1647    let responder_id = UserId::from_proto(request.responder_id);
1648    if requester_id == responder_id {
1649        return Err(anyhow!("cannot add yourself as a contact"))?;
1650    }
1651
1652    session
1653        .db()
1654        .await
1655        .send_contact_request(requester_id, responder_id)
1656        .await?;
1657
1658    // Update outgoing contact requests of requester
1659    let mut update = proto::UpdateContacts::default();
1660    update.outgoing_requests.push(responder_id.to_proto());
1661    for connection_id in session
1662        .connection_pool()
1663        .await
1664        .user_connection_ids(requester_id)
1665    {
1666        session.peer.send(connection_id, update.clone())?;
1667    }
1668
1669    // Update incoming contact requests of responder
1670    let mut update = proto::UpdateContacts::default();
1671    update
1672        .incoming_requests
1673        .push(proto::IncomingContactRequest {
1674            requester_id: requester_id.to_proto(),
1675            should_notify: true,
1676        });
1677    for connection_id in session
1678        .connection_pool()
1679        .await
1680        .user_connection_ids(responder_id)
1681    {
1682        session.peer.send(connection_id, update.clone())?;
1683    }
1684
1685    response.send(proto::Ack {})?;
1686    Ok(())
1687}
1688
1689async fn respond_to_contact_request(
1690    request: proto::RespondToContactRequest,
1691    response: Response<proto::RespondToContactRequest>,
1692    session: Session,
1693) -> Result<()> {
1694    let responder_id = session.user_id;
1695    let requester_id = UserId::from_proto(request.requester_id);
1696    let db = session.db().await;
1697    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1698        db.dismiss_contact_notification(responder_id, requester_id)
1699            .await?;
1700    } else {
1701        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1702
1703        db.respond_to_contact_request(responder_id, requester_id, accept)
1704            .await?;
1705        let requester_busy = db.is_user_busy(requester_id).await?;
1706        let responder_busy = db.is_user_busy(responder_id).await?;
1707
1708        let pool = session.connection_pool().await;
1709        // Update responder with new contact
1710        let mut update = proto::UpdateContacts::default();
1711        if accept {
1712            update
1713                .contacts
1714                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1715        }
1716        update
1717            .remove_incoming_requests
1718            .push(requester_id.to_proto());
1719        for connection_id in pool.user_connection_ids(responder_id) {
1720            session.peer.send(connection_id, update.clone())?;
1721        }
1722
1723        // Update requester with new contact
1724        let mut update = proto::UpdateContacts::default();
1725        if accept {
1726            update
1727                .contacts
1728                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1729        }
1730        update
1731            .remove_outgoing_requests
1732            .push(responder_id.to_proto());
1733        for connection_id in pool.user_connection_ids(requester_id) {
1734            session.peer.send(connection_id, update.clone())?;
1735        }
1736    }
1737
1738    response.send(proto::Ack {})?;
1739    Ok(())
1740}
1741
1742async fn remove_contact(
1743    request: proto::RemoveContact,
1744    response: Response<proto::RemoveContact>,
1745    session: Session,
1746) -> Result<()> {
1747    let requester_id = session.user_id;
1748    let responder_id = UserId::from_proto(request.user_id);
1749    let db = session.db().await;
1750    db.remove_contact(requester_id, responder_id).await?;
1751
1752    let pool = session.connection_pool().await;
1753    // Update outgoing contact requests of requester
1754    let mut update = proto::UpdateContacts::default();
1755    update
1756        .remove_outgoing_requests
1757        .push(responder_id.to_proto());
1758    for connection_id in pool.user_connection_ids(requester_id) {
1759        session.peer.send(connection_id, update.clone())?;
1760    }
1761
1762    // Update incoming contact requests of responder
1763    let mut update = proto::UpdateContacts::default();
1764    update
1765        .remove_incoming_requests
1766        .push(requester_id.to_proto());
1767    for connection_id in pool.user_connection_ids(responder_id) {
1768        session.peer.send(connection_id, update.clone())?;
1769    }
1770
1771    response.send(proto::Ack {})?;
1772    Ok(())
1773}
1774
1775async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1776    let project_id = ProjectId::from_proto(request.project_id);
1777    let project_connection_ids = session
1778        .db()
1779        .await
1780        .project_connection_ids(project_id, session.connection_id)
1781        .await?;
1782    broadcast(
1783        session.connection_id,
1784        project_connection_ids.iter().copied(),
1785        |connection_id| {
1786            session
1787                .peer
1788                .forward_send(session.connection_id, connection_id, request.clone())
1789        },
1790    );
1791    Ok(())
1792}
1793
1794async fn get_private_user_info(
1795    _request: proto::GetPrivateUserInfo,
1796    response: Response<proto::GetPrivateUserInfo>,
1797    session: Session,
1798) -> Result<()> {
1799    let metrics_id = session
1800        .db()
1801        .await
1802        .get_user_metrics_id(session.user_id)
1803        .await?;
1804    let user = session
1805        .db()
1806        .await
1807        .get_user_by_id(session.user_id)
1808        .await?
1809        .ok_or_else(|| anyhow!("user not found"))?;
1810    response.send(proto::GetPrivateUserInfoResponse {
1811        metrics_id,
1812        staff: user.admin,
1813    })?;
1814    Ok(())
1815}
1816
1817fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1818    match message {
1819        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1820        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1821        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1822        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1823        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1824            code: frame.code.into(),
1825            reason: frame.reason,
1826        })),
1827    }
1828}
1829
1830fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1831    match message {
1832        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1833        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1834        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1835        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1836        AxumMessage::Close(frame) => {
1837            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1838                code: frame.code.into(),
1839                reason: frame.reason,
1840            }))
1841        }
1842    }
1843}
1844
1845fn build_initial_contacts_update(
1846    contacts: Vec<db::Contact>,
1847    pool: &ConnectionPool,
1848) -> proto::UpdateContacts {
1849    let mut update = proto::UpdateContacts::default();
1850
1851    for contact in contacts {
1852        match contact {
1853            db::Contact::Accepted {
1854                user_id,
1855                should_notify,
1856                busy,
1857            } => {
1858                update
1859                    .contacts
1860                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1861            }
1862            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1863            db::Contact::Incoming {
1864                user_id,
1865                should_notify,
1866            } => update
1867                .incoming_requests
1868                .push(proto::IncomingContactRequest {
1869                    requester_id: user_id.to_proto(),
1870                    should_notify,
1871                }),
1872        }
1873    }
1874
1875    update
1876}
1877
1878fn contact_for_user(
1879    user_id: UserId,
1880    should_notify: bool,
1881    busy: bool,
1882    pool: &ConnectionPool,
1883) -> proto::Contact {
1884    proto::Contact {
1885        user_id: user_id.to_proto(),
1886        online: pool.is_user_online(user_id),
1887        busy,
1888        should_notify,
1889    }
1890}
1891
1892fn room_updated(room: &proto::Room, peer: &Peer) {
1893    for participant in &room.participants {
1894        peer.send(
1895            ConnectionId(participant.peer_id),
1896            proto::RoomUpdated {
1897                room: Some(room.clone()),
1898            },
1899        )
1900        .trace_err();
1901    }
1902}
1903
1904async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1905    let db = session.db().await;
1906    let contacts = db.get_contacts(user_id).await?;
1907    let busy = db.is_user_busy(user_id).await?;
1908
1909    let pool = session.connection_pool().await;
1910    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1911    for contact in contacts {
1912        if let db::Contact::Accepted {
1913            user_id: contact_user_id,
1914            ..
1915        } = contact
1916        {
1917            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1918                session
1919                    .peer
1920                    .send(
1921                        contact_conn_id,
1922                        proto::UpdateContacts {
1923                            contacts: vec![updated_contact.clone()],
1924                            remove_contacts: Default::default(),
1925                            incoming_requests: Default::default(),
1926                            remove_incoming_requests: Default::default(),
1927                            outgoing_requests: Default::default(),
1928                            remove_outgoing_requests: Default::default(),
1929                        },
1930                    )
1931                    .trace_err();
1932            }
1933        }
1934    }
1935    Ok(())
1936}
1937
1938async fn leave_room_for_session(session: &Session) -> Result<()> {
1939    let mut contacts_to_update = HashSet::default();
1940
1941    let room_id;
1942    let canceled_calls_to_user_ids;
1943    let live_kit_room;
1944    let delete_live_kit_room;
1945    {
1946        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1947        contacts_to_update.insert(session.user_id);
1948
1949        for project in left_room.left_projects.values() {
1950            project_left(project, session);
1951        }
1952
1953        room_updated(&left_room.room, &session.peer);
1954        room_id = RoomId::from_proto(left_room.room.id);
1955        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1956        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1957        delete_live_kit_room = left_room.room.participants.is_empty();
1958    }
1959
1960    {
1961        let pool = session.connection_pool().await;
1962        for canceled_user_id in canceled_calls_to_user_ids {
1963            for connection_id in pool.user_connection_ids(canceled_user_id) {
1964                session
1965                    .peer
1966                    .send(
1967                        connection_id,
1968                        proto::CallCanceled {
1969                            room_id: room_id.to_proto(),
1970                        },
1971                    )
1972                    .trace_err();
1973            }
1974            contacts_to_update.insert(canceled_user_id);
1975        }
1976    }
1977
1978    for contact_user_id in contacts_to_update {
1979        update_user_contacts(contact_user_id, &session).await?;
1980    }
1981
1982    if let Some(live_kit) = session.live_kit_client.as_ref() {
1983        live_kit
1984            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1985            .await
1986            .trace_err();
1987
1988        if delete_live_kit_room {
1989            live_kit.delete_room(live_kit_room).await.trace_err();
1990        }
1991    }
1992
1993    Ok(())
1994}
1995
1996fn project_left(project: &db::LeftProject, session: &Session) {
1997    for connection_id in &project.connection_ids {
1998        if project.host_user_id == session.user_id {
1999            session
2000                .peer
2001                .send(
2002                    *connection_id,
2003                    proto::UnshareProject {
2004                        project_id: project.id.to_proto(),
2005                    },
2006                )
2007                .trace_err();
2008        } else {
2009            session
2010                .peer
2011                .send(
2012                    *connection_id,
2013                    proto::RemoveProjectCollaborator {
2014                        project_id: project.id.to_proto(),
2015                        peer_id: session.connection_id.0,
2016                    },
2017                )
2018                .trace_err();
2019        }
2020    }
2021
2022    session
2023        .peer
2024        .send(
2025            session.connection_id,
2026            proto::UnshareProject {
2027                project_id: project.id.to_proto(),
2028            },
2029        )
2030        .trace_err();
2031}
2032
2033pub trait ResultExt {
2034    type Ok;
2035
2036    fn trace_err(self) -> Option<Self::Ok>;
2037}
2038
2039impl<T, E> ResultExt for Result<T, E>
2040where
2041    E: std::fmt::Debug,
2042{
2043    type Ok = T;
2044
2045    fn trace_err(self) -> Option<T> {
2046        match self {
2047            Ok(value) => Some(value),
2048            Err(error) => {
2049                tracing::error!("{:?}", error);
2050                None
2051            }
2052        }
2053    }
2054}