rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(20);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
/// The collaboration RPC server: owns the `Peer`, the pool of live connections, and the
/// table of message handlers.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Registered handlers, keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sending on this channel signals every subscriber (connection loops and pending
    /// sign-outs) to stop.
    teardown: watch::Sender<()>,
}

/// Lock guard over the connection pool, handed out by `Session::connection_pool` and held by
/// `ServerSnapshot`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, which makes the whole guard `!Send`: it cannot be held across an
    /// `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server state. The connection pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 160
 161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 162where
 163    S: Serializer,
 164    T: Deref<Target = U>,
 165    U: Serialize,
 166{
 167    Serialize::serialize(value.deref(), serializer)
 168}
 169
 170impl Server {
 171    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 172        let mut server = Self {
 173            peer: Peer::new(),
 174            app_state,
 175            executor,
 176            connection_pool: Default::default(),
 177            handlers: Default::default(),
 178            teardown: watch::channel(()).0,
 179        };
 180
 181        server
 182            .add_request_handler(ping)
 183            .add_request_handler(create_room)
 184            .add_request_handler(join_room)
 185            .add_message_handler(leave_room)
 186            .add_request_handler(call)
 187            .add_request_handler(cancel_call)
 188            .add_message_handler(decline_call)
 189            .add_request_handler(update_participant_location)
 190            .add_request_handler(share_project)
 191            .add_message_handler(unshare_project)
 192            .add_request_handler(join_project)
 193            .add_message_handler(leave_project)
 194            .add_request_handler(update_project)
 195            .add_request_handler(update_worktree)
 196            .add_message_handler(start_language_server)
 197            .add_message_handler(update_language_server)
 198            .add_message_handler(update_diagnostic_summary)
 199            .add_request_handler(forward_project_request::<proto::GetHover>)
 200            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 201            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 202            .add_request_handler(forward_project_request::<proto::GetReferences>)
 203            .add_request_handler(forward_project_request::<proto::SearchProject>)
 204            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 205            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 206            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 209            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 210            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 211            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 213            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 214            .add_request_handler(forward_project_request::<proto::PerformRename>)
 215            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 216            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 217            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 218            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 221            .add_message_handler(create_buffer_for_peer)
 222            .add_request_handler(update_buffer)
 223            .add_message_handler(update_buffer_file)
 224            .add_message_handler(buffer_reloaded)
 225            .add_message_handler(buffer_saved)
 226            .add_request_handler(save_buffer)
 227            .add_request_handler(get_users)
 228            .add_request_handler(fuzzy_search_users)
 229            .add_request_handler(request_contact)
 230            .add_request_handler(remove_contact)
 231            .add_request_handler(respond_to_contact_request)
 232            .add_request_handler(follow)
 233            .add_message_handler(unfollow)
 234            .add_message_handler(update_followers)
 235            .add_message_handler(update_diff_base)
 236            .add_request_handler(get_private_user_info);
 237
 238        Arc::new(server)
 239    }
 240
 241    pub async fn start(&self) -> Result<()> {
 242        self.app_state.db.delete_stale_projects().await?;
 243        let db = self.app_state.db.clone();
 244        let peer = self.peer.clone();
 245        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 246        let pool = self.connection_pool.clone();
 247        let live_kit_client = self.app_state.live_kit_client.clone();
 248        self.executor.spawn_detached(async move {
 249            timeout.await;
 250            if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
 251                for room_id in room_ids {
 252                    let mut contacts_to_update = HashSet::default();
 253                    let mut canceled_calls_to_user_ids = Vec::new();
 254                    let mut live_kit_room = String::new();
 255                    let mut delete_live_kit_room = false;
 256
 257                    if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 258                        room_updated(&refreshed_room.room, &peer);
 259                        contacts_to_update
 260                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 261                        contacts_to_update
 262                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 263                        canceled_calls_to_user_ids =
 264                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 265                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 266                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
 267                    }
 268
 269                    {
 270                        let pool = pool.lock();
 271                        for canceled_user_id in canceled_calls_to_user_ids {
 272                            for connection_id in pool.user_connection_ids(canceled_user_id) {
 273                                peer.send(
 274                                    connection_id,
 275                                    proto::CallCanceled {
 276                                        room_id: room_id.to_proto(),
 277                                    },
 278                                )
 279                                .trace_err();
 280                            }
 281                        }
 282                    }
 283
 284                    for user_id in contacts_to_update {
 285                        let busy = db.is_user_busy(user_id).await.trace_err();
 286                        let contacts = db.get_contacts(user_id).await.trace_err();
 287                        if let Some((busy, contacts)) = busy.zip(contacts) {
 288                            let pool = pool.lock();
 289                            let updated_contact = contact_for_user(user_id, false, busy, &pool);
 290                            for contact in contacts {
 291                                if let db::Contact::Accepted {
 292                                    user_id: contact_user_id,
 293                                    ..
 294                                } = contact
 295                                {
 296                                    for contact_conn_id in pool.user_connection_ids(contact_user_id)
 297                                    {
 298                                        peer.send(
 299                                            contact_conn_id,
 300                                            proto::UpdateContacts {
 301                                                contacts: vec![updated_contact.clone()],
 302                                                remove_contacts: Default::default(),
 303                                                incoming_requests: Default::default(),
 304                                                remove_incoming_requests: Default::default(),
 305                                                outgoing_requests: Default::default(),
 306                                                remove_outgoing_requests: Default::default(),
 307                                            },
 308                                        )
 309                                        .trace_err();
 310                                    }
 311                                }
 312                            }
 313                        }
 314                    }
 315
 316                    if let Some(live_kit) = live_kit_client.as_ref() {
 317                        if delete_live_kit_room {
 318                            live_kit.delete_room(live_kit_room).await.trace_err();
 319                        }
 320                    }
 321                }
 322            }
 323        });
 324        Ok(())
 325    }
 326
 327    pub fn teardown(&self) {
 328        self.peer.reset();
 329        self.connection_pool.lock().reset();
 330        let _ = self.teardown.send(());
 331    }
 332
 333    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 334    where
 335        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 336        Fut: 'static + Send + Future<Output = Result<()>>,
 337        M: EnvelopedMessage,
 338    {
 339        let prev_handler = self.handlers.insert(
 340            TypeId::of::<M>(),
 341            Box::new(move |envelope, session| {
 342                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 343                let span = info_span!(
 344                    "handle message",
 345                    payload_type = envelope.payload_type_name()
 346                );
 347                span.in_scope(|| {
 348                    tracing::info!(
 349                        payload_type = envelope.payload_type_name(),
 350                        "message received"
 351                    );
 352                });
 353                let future = (handler)(*envelope, session);
 354                async move {
 355                    if let Err(error) = future.await {
 356                        tracing::error!(%error, "error handling message");
 357                    }
 358                }
 359                .instrument(span)
 360                .boxed()
 361            }),
 362        );
 363        if prev_handler.is_some() {
 364            panic!("registered a handler for the same message twice");
 365        }
 366        self
 367    }
 368
 369    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 370    where
 371        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 372        Fut: 'static + Send + Future<Output = Result<()>>,
 373        M: EnvelopedMessage,
 374    {
 375        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 376        self
 377    }
 378
 379    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 380    where
 381        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 382        Fut: Send + Future<Output = Result<()>>,
 383        M: RequestMessage,
 384    {
 385        let handler = Arc::new(handler);
 386        self.add_handler(move |envelope, session| {
 387            let receipt = envelope.receipt();
 388            let handler = handler.clone();
 389            async move {
 390                let peer = session.peer.clone();
 391                let responded = Arc::new(AtomicBool::default());
 392                let response = Response {
 393                    peer: peer.clone(),
 394                    responded: responded.clone(),
 395                    receipt,
 396                };
 397                match (handler)(envelope.payload, response, session).await {
 398                    Ok(()) => {
 399                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 400                            Ok(())
 401                        } else {
 402                            Err(anyhow!("handler did not send a response"))?
 403                        }
 404                    }
 405                    Err(error) => {
 406                        peer.respond_with_error(
 407                            receipt,
 408                            proto::Error {
 409                                message: error.to_string(),
 410                            },
 411                        )?;
 412                        Err(error)
 413                    }
 414                }
 415            }
 416        })
 417    }
 418
 419    pub fn handle_connection(
 420        self: &Arc<Self>,
 421        connection: Connection,
 422        address: String,
 423        user: User,
 424        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 425        executor: Executor,
 426    ) -> impl Future<Output = Result<()>> {
 427        let this = self.clone();
 428        let user_id = user.id;
 429        let login = user.github_login;
 430        let span = info_span!("handle connection", %user_id, %login, %address);
 431        let mut teardown = self.teardown.subscribe();
 432        async move {
 433            let (connection_id, handle_io, mut incoming_rx) = this
 434                .peer
 435                .add_connection(connection, {
 436                    let executor = executor.clone();
 437                    move |duration| executor.sleep(duration)
 438                });
 439
 440            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 441            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 442            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 443
 444            if let Some(send_connection_id) = send_connection_id.take() {
 445                let _ = send_connection_id.send(connection_id);
 446            }
 447
 448            if !user.connected_once {
 449                this.peer.send(connection_id, proto::ShowContacts {})?;
 450                this.app_state.db.set_user_connected_once(user_id, true).await?;
 451            }
 452
 453            let (contacts, invite_code) = future::try_join(
 454                this.app_state.db.get_contacts(user_id),
 455                this.app_state.db.get_invite_code_for_user(user_id)
 456            ).await?;
 457
 458            {
 459                let mut pool = this.connection_pool.lock();
 460                pool.add_connection(connection_id, user_id, user.admin);
 461                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 462
 463                if let Some((code, count)) = invite_code {
 464                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 465                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 466                        count: count as u32,
 467                    })?;
 468                }
 469            }
 470
 471            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 472                this.peer.send(connection_id, incoming_call)?;
 473            }
 474
 475            let session = Session {
 476                user_id,
 477                connection_id,
 478                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 479                peer: this.peer.clone(),
 480                connection_pool: this.connection_pool.clone(),
 481                live_kit_client: this.app_state.live_kit_client.clone()
 482            };
 483            update_user_contacts(user_id, &session).await?;
 484
 485            let handle_io = handle_io.fuse();
 486            futures::pin_mut!(handle_io);
 487
 488            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 489            // This prevents deadlocks when e.g., client A performs a request to client B and
 490            // client B performs a request to client A. If both clients stop processing further
 491            // messages until their respective request completes, they won't have a chance to
 492            // respond to the other client's request and cause a deadlock.
 493            //
 494            // This arrangement ensures we will attempt to process earlier messages first, but fall
 495            // back to processing messages arrived later in the spirit of making progress.
 496            let mut foreground_message_handlers = FuturesUnordered::new();
 497            loop {
 498                let next_message = incoming_rx.next().fuse();
 499                futures::pin_mut!(next_message);
 500                futures::select_biased! {
 501                    _ = teardown.changed().fuse() => return Ok(()),
 502                    result = handle_io => {
 503                        if let Err(error) = result {
 504                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 505                        }
 506                        break;
 507                    }
 508                    _ = foreground_message_handlers.next() => {}
 509                    message = next_message => {
 510                        if let Some(message) = message {
 511                            let type_name = message.payload_type_name();
 512                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 513                            let span_enter = span.enter();
 514                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 515                                let is_background = message.is_background();
 516                                let handle_message = (handler)(message, session.clone());
 517                                drop(span_enter);
 518
 519                                let handle_message = handle_message.instrument(span);
 520                                if is_background {
 521                                    executor.spawn_detached(handle_message);
 522                                } else {
 523                                    foreground_message_handlers.push(handle_message);
 524                                }
 525                            } else {
 526                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 527                            }
 528                        } else {
 529                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 530                            break;
 531                        }
 532                    }
 533                }
 534            }
 535
 536            drop(foreground_message_handlers);
 537            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 538            if let Err(error) = sign_out(session, teardown, executor).await {
 539                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 540            }
 541
 542            Ok(())
 543        }.instrument(span)
 544    }
 545
 546    pub async fn invite_code_redeemed(
 547        self: &Arc<Self>,
 548        inviter_id: UserId,
 549        invitee_id: UserId,
 550    ) -> Result<()> {
 551        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 552            if let Some(code) = &user.invite_code {
 553                let pool = self.connection_pool.lock();
 554                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 555                for connection_id in pool.user_connection_ids(inviter_id) {
 556                    self.peer.send(
 557                        connection_id,
 558                        proto::UpdateContacts {
 559                            contacts: vec![invitee_contact.clone()],
 560                            ..Default::default()
 561                        },
 562                    )?;
 563                    self.peer.send(
 564                        connection_id,
 565                        proto::UpdateInviteInfo {
 566                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 567                            count: user.invite_count as u32,
 568                        },
 569                    )?;
 570                }
 571            }
 572        }
 573        Ok(())
 574    }
 575
 576    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 577        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 578            if let Some(invite_code) = &user.invite_code {
 579                let pool = self.connection_pool.lock();
 580                for connection_id in pool.user_connection_ids(user_id) {
 581                    self.peer.send(
 582                        connection_id,
 583                        proto::UpdateInviteInfo {
 584                            url: format!(
 585                                "{}{}",
 586                                self.app_state.config.invite_link_prefix, invite_code
 587                            ),
 588                            count: user.invite_count as u32,
 589                        },
 590                    )?;
 591                }
 592            }
 593        }
 594        Ok(())
 595    }
 596
 597    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 598        ServerSnapshot {
 599            connection_pool: ConnectionPoolGuard {
 600                guard: self.connection_pool.lock(),
 601                _not_send: PhantomData,
 602            },
 603            peer: &self.peer,
 604        }
 605    }
 606}
 607
 608impl<'a> Deref for ConnectionPoolGuard<'a> {
 609    type Target = ConnectionPool;
 610
 611    fn deref(&self) -> &Self::Target {
 612        &*self.guard
 613    }
 614}
 615
 616impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 617    fn deref_mut(&mut self) -> &mut Self::Target {
 618        &mut *self.guard
 619    }
 620}
 621
 622impl<'a> Drop for ConnectionPoolGuard<'a> {
 623    fn drop(&mut self) {
 624        #[cfg(test)]
 625        self.check_invariants();
 626    }
 627}
 628
 629fn broadcast<F>(
 630    sender_id: ConnectionId,
 631    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 632    mut f: F,
 633) where
 634    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 635{
 636    for receiver_id in receiver_ids {
 637        if receiver_id != sender_id {
 638            f(receiver_id).trace_err();
 639        }
 640    }
 641}
 642
 643lazy_static! {
 644    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 645}
 646
 647pub struct ProtocolVersion(u32);
 648
 649impl Header for ProtocolVersion {
 650    fn name() -> &'static HeaderName {
 651        &ZED_PROTOCOL_VERSION
 652    }
 653
 654    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 655    where
 656        Self: Sized,
 657        I: Iterator<Item = &'i axum::http::HeaderValue>,
 658    {
 659        let version = values
 660            .next()
 661            .ok_or_else(axum::headers::Error::invalid)?
 662            .to_str()
 663            .map_err(|_| axum::headers::Error::invalid())?
 664            .parse()
 665            .map_err(|_| axum::headers::Error::invalid())?;
 666        Ok(Self(version))
 667    }
 668
 669    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 670        values.extend([self.0.to_string().parse().unwrap()]);
 671    }
 672}
 673
 674pub fn routes(server: Arc<Server>) -> Router<Body> {
 675    Router::new()
 676        .route("/rpc", get(handle_websocket_request))
 677        .layer(
 678            ServiceBuilder::new()
 679                .layer(Extension(server.app_state.clone()))
 680                .layer(middleware::from_fn(auth::validate_header)),
 681        )
 682        .route("/metrics", get(handle_metrics))
 683        .layer(Extension(server))
 684}
 685
 686pub async fn handle_websocket_request(
 687    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 688    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 689    Extension(server): Extension<Arc<Server>>,
 690    Extension(user): Extension<User>,
 691    ws: WebSocketUpgrade,
 692) -> axum::response::Response {
 693    if protocol_version != rpc::PROTOCOL_VERSION {
 694        return (
 695            StatusCode::UPGRADE_REQUIRED,
 696            "client must be upgraded".to_string(),
 697        )
 698            .into_response();
 699    }
 700    let socket_address = socket_address.to_string();
 701    ws.on_upgrade(move |socket| {
 702        use util::ResultExt;
 703        let socket = socket
 704            .map_ok(to_tungstenite_message)
 705            .err_into()
 706            .with(|message| async move { Ok(to_axum_message(message)) });
 707        let connection = Connection::new(Box::pin(socket));
 708        async move {
 709            server
 710                .handle_connection(connection, socket_address, user, None, Executor::Production)
 711                .await
 712                .log_err();
 713        }
 714    })
 715}
 716
 717pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 718    let connections = server
 719        .connection_pool
 720        .lock()
 721        .connections()
 722        .filter(|connection| !connection.admin)
 723        .count();
 724
 725    METRIC_CONNECTIONS.set(connections as _);
 726
 727    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 728    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 729
 730    let encoder = prometheus::TextEncoder::new();
 731    let metric_families = prometheus::gather();
 732    let encoded_metrics = encoder
 733        .encode_to_string(&metric_families)
 734        .map_err(|err| anyhow!("{}", err))?;
 735    Ok(encoded_metrics)
 736}
 737
/// Cleans up after a connection that has ended.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - reports the lost connection to the database, calling `project_left` for every project
///   it left.
///
/// It then races `RECONNECT_TIMEOUT` against server teardown. If the timeout fires first:
/// - the session leaves its room;
/// - if the user is no longer online, `decline_call` runs and any returned room is
///   broadcast;
/// - `update_user_contacts` is called.
///
/// On teardown, those steps are skipped.
///
/// # Errors
///
/// Returns early if `remove_connection` or `update_user_contacts` fails. Failures from
/// `connection_lost`, `leave_room_for_session` and `decline_call` go through `trace_err`
/// instead, so the remaining cleanup still runs.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // The db guard is a temporary of the `if let` scrutinee, so it stays held while
    // `project_left` runs for each project.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls the reconnect timeout branch before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // This connection was removed from the pool above, so presumably this asks
            // whether the user still has any other connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 783
 784async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 785    response.send(proto::Ack {})?;
 786    Ok(())
 787}
 788
 789async fn create_room(
 790    _request: proto::CreateRoom,
 791    response: Response<proto::CreateRoom>,
 792    session: Session,
 793) -> Result<()> {
 794    let live_kit_room = nanoid::nanoid!(30);
 795    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 796        if let Some(_) = live_kit
 797            .create_room(live_kit_room.clone())
 798            .await
 799            .trace_err()
 800        {
 801            if let Some(token) = live_kit
 802                .room_token(&live_kit_room, &session.connection_id.to_string())
 803                .trace_err()
 804            {
 805                Some(proto::LiveKitConnectionInfo {
 806                    server_url: live_kit.url().into(),
 807                    token,
 808                })
 809            } else {
 810                None
 811            }
 812        } else {
 813            None
 814        }
 815    } else {
 816        None
 817    };
 818
 819    {
 820        let room = session
 821            .db()
 822            .await
 823            .create_room(session.user_id, session.connection_id, &live_kit_room)
 824            .await?;
 825
 826        response.send(proto::CreateRoomResponse {
 827            room: Some(room.clone()),
 828            live_kit_connection_info,
 829        })?;
 830    }
 831
 832    update_user_contacts(session.user_id, &session).await?;
 833    Ok(())
 834}
 835
 836async fn join_room(
 837    request: proto::JoinRoom,
 838    response: Response<proto::JoinRoom>,
 839    session: Session,
 840) -> Result<()> {
 841    let room_id = RoomId::from_proto(request.id);
 842    let room = {
 843        let room = session
 844            .db()
 845            .await
 846            .join_room(room_id, session.user_id, session.connection_id)
 847            .await?;
 848        room_updated(&room, &session.peer);
 849        room.clone()
 850    };
 851
 852    for connection_id in session
 853        .connection_pool()
 854        .await
 855        .user_connection_ids(session.user_id)
 856    {
 857        session
 858            .peer
 859            .send(
 860                connection_id,
 861                proto::CallCanceled {
 862                    room_id: room_id.to_proto(),
 863                },
 864            )
 865            .trace_err();
 866    }
 867
 868    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 869        if let Some(token) = live_kit
 870            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 871            .trace_err()
 872        {
 873            Some(proto::LiveKitConnectionInfo {
 874                server_url: live_kit.url().into(),
 875                token,
 876            })
 877        } else {
 878            None
 879        }
 880    } else {
 881        None
 882    };
 883
 884    response.send(proto::JoinRoomResponse {
 885        room: Some(room),
 886        live_kit_connection_info,
 887    })?;
 888
 889    update_user_contacts(session.user_id, &session).await?;
 890    Ok(())
 891}
 892
 893async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 894    leave_room_for_session(&session).await
 895}
 896
 897async fn call(
 898    request: proto::Call,
 899    response: Response<proto::Call>,
 900    session: Session,
 901) -> Result<()> {
 902    let room_id = RoomId::from_proto(request.room_id);
 903    let calling_user_id = session.user_id;
 904    let calling_connection_id = session.connection_id;
 905    let called_user_id = UserId::from_proto(request.called_user_id);
 906    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 907    if !session
 908        .db()
 909        .await
 910        .has_contact(calling_user_id, called_user_id)
 911        .await?
 912    {
 913        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 914    }
 915
 916    let incoming_call = {
 917        let (room, incoming_call) = &mut *session
 918            .db()
 919            .await
 920            .call(
 921                room_id,
 922                calling_user_id,
 923                calling_connection_id,
 924                called_user_id,
 925                initial_project_id,
 926            )
 927            .await?;
 928        room_updated(&room, &session.peer);
 929        mem::take(incoming_call)
 930    };
 931    update_user_contacts(called_user_id, &session).await?;
 932
 933    let mut calls = session
 934        .connection_pool()
 935        .await
 936        .user_connection_ids(called_user_id)
 937        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 938        .collect::<FuturesUnordered<_>>();
 939
 940    while let Some(call_response) = calls.next().await {
 941        match call_response.as_ref() {
 942            Ok(_) => {
 943                response.send(proto::Ack {})?;
 944                return Ok(());
 945            }
 946            Err(_) => {
 947                call_response.trace_err();
 948            }
 949        }
 950    }
 951
 952    {
 953        let room = session
 954            .db()
 955            .await
 956            .call_failed(room_id, called_user_id)
 957            .await?;
 958        room_updated(&room, &session.peer);
 959    }
 960    update_user_contacts(called_user_id, &session).await?;
 961
 962    Err(anyhow!("failed to ring user"))?
 963}
 964
 965async fn cancel_call(
 966    request: proto::CancelCall,
 967    response: Response<proto::CancelCall>,
 968    session: Session,
 969) -> Result<()> {
 970    let called_user_id = UserId::from_proto(request.called_user_id);
 971    let room_id = RoomId::from_proto(request.room_id);
 972    {
 973        let room = session
 974            .db()
 975            .await
 976            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 977            .await?;
 978        room_updated(&room, &session.peer);
 979    }
 980
 981    for connection_id in session
 982        .connection_pool()
 983        .await
 984        .user_connection_ids(called_user_id)
 985    {
 986        session
 987            .peer
 988            .send(
 989                connection_id,
 990                proto::CallCanceled {
 991                    room_id: room_id.to_proto(),
 992                },
 993            )
 994            .trace_err();
 995    }
 996    response.send(proto::Ack {})?;
 997
 998    update_user_contacts(called_user_id, &session).await?;
 999    Ok(())
1000}
1001
1002async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1003    let room_id = RoomId::from_proto(message.room_id);
1004    {
1005        let room = session
1006            .db()
1007            .await
1008            .decline_call(Some(room_id), session.user_id)
1009            .await?;
1010        room_updated(&room, &session.peer);
1011    }
1012
1013    for connection_id in session
1014        .connection_pool()
1015        .await
1016        .user_connection_ids(session.user_id)
1017    {
1018        session
1019            .peer
1020            .send(
1021                connection_id,
1022                proto::CallCanceled {
1023                    room_id: room_id.to_proto(),
1024                },
1025            )
1026            .trace_err();
1027    }
1028    update_user_contacts(session.user_id, &session).await?;
1029    Ok(())
1030}
1031
1032async fn update_participant_location(
1033    request: proto::UpdateParticipantLocation,
1034    response: Response<proto::UpdateParticipantLocation>,
1035    session: Session,
1036) -> Result<()> {
1037    let room_id = RoomId::from_proto(request.room_id);
1038    let location = request
1039        .location
1040        .ok_or_else(|| anyhow!("invalid location"))?;
1041    let room = session
1042        .db()
1043        .await
1044        .update_room_participant_location(room_id, session.connection_id, location)
1045        .await?;
1046    room_updated(&room, &session.peer);
1047    response.send(proto::Ack {})?;
1048    Ok(())
1049}
1050
1051async fn share_project(
1052    request: proto::ShareProject,
1053    response: Response<proto::ShareProject>,
1054    session: Session,
1055) -> Result<()> {
1056    let (project_id, room) = &*session
1057        .db()
1058        .await
1059        .share_project(
1060            RoomId::from_proto(request.room_id),
1061            session.connection_id,
1062            &request.worktrees,
1063        )
1064        .await?;
1065    response.send(proto::ShareProjectResponse {
1066        project_id: project_id.to_proto(),
1067    })?;
1068    room_updated(&room, &session.peer);
1069
1070    Ok(())
1071}
1072
1073async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1074    let project_id = ProjectId::from_proto(message.project_id);
1075
1076    let (room, guest_connection_ids) = &*session
1077        .db()
1078        .await
1079        .unshare_project(project_id, session.connection_id)
1080        .await?;
1081
1082    broadcast(
1083        session.connection_id,
1084        guest_connection_ids.iter().copied(),
1085        |conn_id| session.peer.send(conn_id, message.clone()),
1086    );
1087    room_updated(&room, &session.peer);
1088
1089    Ok(())
1090}
1091
1092async fn join_project(
1093    request: proto::JoinProject,
1094    response: Response<proto::JoinProject>,
1095    session: Session,
1096) -> Result<()> {
1097    let project_id = ProjectId::from_proto(request.project_id);
1098    let guest_user_id = session.user_id;
1099
1100    tracing::info!(%project_id, "join project");
1101
1102    let (project, replica_id) = &mut *session
1103        .db()
1104        .await
1105        .join_project(project_id, session.connection_id)
1106        .await?;
1107
1108    let collaborators = project
1109        .collaborators
1110        .iter()
1111        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1112        .map(|collaborator| proto::Collaborator {
1113            peer_id: collaborator.connection_id as u32,
1114            replica_id: collaborator.replica_id.0 as u32,
1115            user_id: collaborator.user_id.to_proto(),
1116        })
1117        .collect::<Vec<_>>();
1118    let worktrees = project
1119        .worktrees
1120        .iter()
1121        .map(|(id, worktree)| proto::WorktreeMetadata {
1122            id: *id,
1123            root_name: worktree.root_name.clone(),
1124            visible: worktree.visible,
1125            abs_path: worktree.abs_path.clone(),
1126        })
1127        .collect::<Vec<_>>();
1128
1129    for collaborator in &collaborators {
1130        session
1131            .peer
1132            .send(
1133                ConnectionId(collaborator.peer_id),
1134                proto::AddProjectCollaborator {
1135                    project_id: project_id.to_proto(),
1136                    collaborator: Some(proto::Collaborator {
1137                        peer_id: session.connection_id.0,
1138                        replica_id: replica_id.0 as u32,
1139                        user_id: guest_user_id.to_proto(),
1140                    }),
1141                },
1142            )
1143            .trace_err();
1144    }
1145
1146    // First, we send the metadata associated with each worktree.
1147    response.send(proto::JoinProjectResponse {
1148        worktrees: worktrees.clone(),
1149        replica_id: replica_id.0 as u32,
1150        collaborators: collaborators.clone(),
1151        language_servers: project.language_servers.clone(),
1152    })?;
1153
1154    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1155        #[cfg(any(test, feature = "test-support"))]
1156        const MAX_CHUNK_SIZE: usize = 2;
1157        #[cfg(not(any(test, feature = "test-support")))]
1158        const MAX_CHUNK_SIZE: usize = 256;
1159
1160        // Stream this worktree's entries.
1161        let message = proto::UpdateWorktree {
1162            project_id: project_id.to_proto(),
1163            worktree_id,
1164            abs_path: worktree.abs_path.clone(),
1165            root_name: worktree.root_name,
1166            updated_entries: worktree.entries,
1167            removed_entries: Default::default(),
1168            scan_id: worktree.scan_id,
1169            is_last_update: worktree.is_complete,
1170        };
1171        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1172            session.peer.send(session.connection_id, update.clone())?;
1173        }
1174
1175        // Stream this worktree's diagnostics.
1176        for summary in worktree.diagnostic_summaries {
1177            session.peer.send(
1178                session.connection_id,
1179                proto::UpdateDiagnosticSummary {
1180                    project_id: project_id.to_proto(),
1181                    worktree_id: worktree.id,
1182                    summary: Some(summary),
1183                },
1184            )?;
1185        }
1186    }
1187
1188    for language_server in &project.language_servers {
1189        session.peer.send(
1190            session.connection_id,
1191            proto::UpdateLanguageServer {
1192                project_id: project_id.to_proto(),
1193                language_server_id: language_server.id,
1194                variant: Some(
1195                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1196                        proto::LspDiskBasedDiagnosticsUpdated {},
1197                    ),
1198                ),
1199            },
1200        )?;
1201    }
1202
1203    Ok(())
1204}
1205
1206async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1207    let sender_id = session.connection_id;
1208    let project_id = ProjectId::from_proto(request.project_id);
1209
1210    let project = session
1211        .db()
1212        .await
1213        .leave_project(project_id, sender_id)
1214        .await?;
1215    tracing::info!(
1216        %project_id,
1217        host_user_id = %project.host_user_id,
1218        host_connection_id = %project.host_connection_id,
1219        "leave project"
1220    );
1221    project_left(&project, &session);
1222
1223    Ok(())
1224}
1225
1226async fn update_project(
1227    request: proto::UpdateProject,
1228    response: Response<proto::UpdateProject>,
1229    session: Session,
1230) -> Result<()> {
1231    let project_id = ProjectId::from_proto(request.project_id);
1232    let (room, guest_connection_ids) = &*session
1233        .db()
1234        .await
1235        .update_project(project_id, session.connection_id, &request.worktrees)
1236        .await?;
1237    broadcast(
1238        session.connection_id,
1239        guest_connection_ids.iter().copied(),
1240        |connection_id| {
1241            session
1242                .peer
1243                .forward_send(session.connection_id, connection_id, request.clone())
1244        },
1245    );
1246    room_updated(&room, &session.peer);
1247    response.send(proto::Ack {})?;
1248
1249    Ok(())
1250}
1251
1252async fn update_worktree(
1253    request: proto::UpdateWorktree,
1254    response: Response<proto::UpdateWorktree>,
1255    session: Session,
1256) -> Result<()> {
1257    let guest_connection_ids = session
1258        .db()
1259        .await
1260        .update_worktree(&request, session.connection_id)
1261        .await?;
1262
1263    broadcast(
1264        session.connection_id,
1265        guest_connection_ids.iter().copied(),
1266        |connection_id| {
1267            session
1268                .peer
1269                .forward_send(session.connection_id, connection_id, request.clone())
1270        },
1271    );
1272    response.send(proto::Ack {})?;
1273    Ok(())
1274}
1275
1276async fn update_diagnostic_summary(
1277    message: proto::UpdateDiagnosticSummary,
1278    session: Session,
1279) -> Result<()> {
1280    let guest_connection_ids = session
1281        .db()
1282        .await
1283        .update_diagnostic_summary(&message, session.connection_id)
1284        .await?;
1285
1286    broadcast(
1287        session.connection_id,
1288        guest_connection_ids.iter().copied(),
1289        |connection_id| {
1290            session
1291                .peer
1292                .forward_send(session.connection_id, connection_id, message.clone())
1293        },
1294    );
1295
1296    Ok(())
1297}
1298
1299async fn start_language_server(
1300    request: proto::StartLanguageServer,
1301    session: Session,
1302) -> Result<()> {
1303    let guest_connection_ids = session
1304        .db()
1305        .await
1306        .start_language_server(&request, session.connection_id)
1307        .await?;
1308
1309    broadcast(
1310        session.connection_id,
1311        guest_connection_ids.iter().copied(),
1312        |connection_id| {
1313            session
1314                .peer
1315                .forward_send(session.connection_id, connection_id, request.clone())
1316        },
1317    );
1318    Ok(())
1319}
1320
1321async fn update_language_server(
1322    request: proto::UpdateLanguageServer,
1323    session: Session,
1324) -> Result<()> {
1325    let project_id = ProjectId::from_proto(request.project_id);
1326    let project_connection_ids = session
1327        .db()
1328        .await
1329        .project_connection_ids(project_id, session.connection_id)
1330        .await?;
1331    broadcast(
1332        session.connection_id,
1333        project_connection_ids.iter().copied(),
1334        |connection_id| {
1335            session
1336                .peer
1337                .forward_send(session.connection_id, connection_id, request.clone())
1338        },
1339    );
1340    Ok(())
1341}
1342
1343async fn forward_project_request<T>(
1344    request: T,
1345    response: Response<T>,
1346    session: Session,
1347) -> Result<()>
1348where
1349    T: EntityMessage + RequestMessage,
1350{
1351    let project_id = ProjectId::from_proto(request.remote_entity_id());
1352    let host_connection_id = {
1353        let collaborators = session
1354            .db()
1355            .await
1356            .project_collaborators(project_id, session.connection_id)
1357            .await?;
1358        ConnectionId(
1359            collaborators
1360                .iter()
1361                .find(|collaborator| collaborator.is_host)
1362                .ok_or_else(|| anyhow!("host not found"))?
1363                .connection_id as u32,
1364        )
1365    };
1366
1367    let payload = session
1368        .peer
1369        .forward_request(session.connection_id, host_connection_id, request)
1370        .await?;
1371
1372    response.send(payload)?;
1373    Ok(())
1374}
1375
1376async fn save_buffer(
1377    request: proto::SaveBuffer,
1378    response: Response<proto::SaveBuffer>,
1379    session: Session,
1380) -> Result<()> {
1381    let project_id = ProjectId::from_proto(request.project_id);
1382    let host_connection_id = {
1383        let collaborators = session
1384            .db()
1385            .await
1386            .project_collaborators(project_id, session.connection_id)
1387            .await?;
1388        let host = collaborators
1389            .iter()
1390            .find(|collaborator| collaborator.is_host)
1391            .ok_or_else(|| anyhow!("host not found"))?;
1392        ConnectionId(host.connection_id as u32)
1393    };
1394    let response_payload = session
1395        .peer
1396        .forward_request(session.connection_id, host_connection_id, request.clone())
1397        .await?;
1398
1399    let mut collaborators = session
1400        .db()
1401        .await
1402        .project_collaborators(project_id, session.connection_id)
1403        .await?;
1404    collaborators
1405        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1406    let project_connection_ids = collaborators
1407        .iter()
1408        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1409    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1410        session
1411            .peer
1412            .forward_send(host_connection_id, conn_id, response_payload.clone())
1413    });
1414    response.send(response_payload)?;
1415    Ok(())
1416}
1417
1418async fn create_buffer_for_peer(
1419    request: proto::CreateBufferForPeer,
1420    session: Session,
1421) -> Result<()> {
1422    session.peer.forward_send(
1423        session.connection_id,
1424        ConnectionId(request.peer_id),
1425        request,
1426    )?;
1427    Ok(())
1428}
1429
1430async fn update_buffer(
1431    request: proto::UpdateBuffer,
1432    response: Response<proto::UpdateBuffer>,
1433    session: Session,
1434) -> Result<()> {
1435    let project_id = ProjectId::from_proto(request.project_id);
1436    let project_connection_ids = session
1437        .db()
1438        .await
1439        .project_connection_ids(project_id, session.connection_id)
1440        .await?;
1441
1442    broadcast(
1443        session.connection_id,
1444        project_connection_ids.iter().copied(),
1445        |connection_id| {
1446            session
1447                .peer
1448                .forward_send(session.connection_id, connection_id, request.clone())
1449        },
1450    );
1451    response.send(proto::Ack {})?;
1452    Ok(())
1453}
1454
1455async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1456    let project_id = ProjectId::from_proto(request.project_id);
1457    let project_connection_ids = session
1458        .db()
1459        .await
1460        .project_connection_ids(project_id, session.connection_id)
1461        .await?;
1462
1463    broadcast(
1464        session.connection_id,
1465        project_connection_ids.iter().copied(),
1466        |connection_id| {
1467            session
1468                .peer
1469                .forward_send(session.connection_id, connection_id, request.clone())
1470        },
1471    );
1472    Ok(())
1473}
1474
1475async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1476    let project_id = ProjectId::from_proto(request.project_id);
1477    let project_connection_ids = session
1478        .db()
1479        .await
1480        .project_connection_ids(project_id, session.connection_id)
1481        .await?;
1482    broadcast(
1483        session.connection_id,
1484        project_connection_ids.iter().copied(),
1485        |connection_id| {
1486            session
1487                .peer
1488                .forward_send(session.connection_id, connection_id, request.clone())
1489        },
1490    );
1491    Ok(())
1492}
1493
1494async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1495    let project_id = ProjectId::from_proto(request.project_id);
1496    let project_connection_ids = session
1497        .db()
1498        .await
1499        .project_connection_ids(project_id, session.connection_id)
1500        .await?;
1501    broadcast(
1502        session.connection_id,
1503        project_connection_ids.iter().copied(),
1504        |connection_id| {
1505            session
1506                .peer
1507                .forward_send(session.connection_id, connection_id, request.clone())
1508        },
1509    );
1510    Ok(())
1511}
1512
1513async fn follow(
1514    request: proto::Follow,
1515    response: Response<proto::Follow>,
1516    session: Session,
1517) -> Result<()> {
1518    let project_id = ProjectId::from_proto(request.project_id);
1519    let leader_id = ConnectionId(request.leader_id);
1520    let follower_id = session.connection_id;
1521    {
1522        let project_connection_ids = session
1523            .db()
1524            .await
1525            .project_connection_ids(project_id, session.connection_id)
1526            .await?;
1527
1528        if !project_connection_ids.contains(&leader_id) {
1529            Err(anyhow!("no such peer"))?;
1530        }
1531    }
1532
1533    let mut response_payload = session
1534        .peer
1535        .forward_request(session.connection_id, leader_id, request)
1536        .await?;
1537    response_payload
1538        .views
1539        .retain(|view| view.leader_id != Some(follower_id.0));
1540    response.send(response_payload)?;
1541    Ok(())
1542}
1543
1544async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1545    let project_id = ProjectId::from_proto(request.project_id);
1546    let leader_id = ConnectionId(request.leader_id);
1547    let project_connection_ids = session
1548        .db()
1549        .await
1550        .project_connection_ids(project_id, session.connection_id)
1551        .await?;
1552    if !project_connection_ids.contains(&leader_id) {
1553        Err(anyhow!("no such peer"))?;
1554    }
1555    session
1556        .peer
1557        .forward_send(session.connection_id, leader_id, request)?;
1558    Ok(())
1559}
1560
1561async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1562    let project_id = ProjectId::from_proto(request.project_id);
1563    let project_connection_ids = session
1564        .db
1565        .lock()
1566        .await
1567        .project_connection_ids(project_id, session.connection_id)
1568        .await?;
1569
1570    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1571        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1572        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1573        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1574    });
1575    for follower_id in &request.follower_ids {
1576        let follower_id = ConnectionId(*follower_id);
1577        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1578            session
1579                .peer
1580                .forward_send(session.connection_id, follower_id, request.clone())?;
1581        }
1582    }
1583    Ok(())
1584}
1585
1586async fn get_users(
1587    request: proto::GetUsers,
1588    response: Response<proto::GetUsers>,
1589    session: Session,
1590) -> Result<()> {
1591    let user_ids = request
1592        .user_ids
1593        .into_iter()
1594        .map(UserId::from_proto)
1595        .collect();
1596    let users = session
1597        .db()
1598        .await
1599        .get_users_by_ids(user_ids)
1600        .await?
1601        .into_iter()
1602        .map(|user| proto::User {
1603            id: user.id.to_proto(),
1604            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1605            github_login: user.github_login,
1606        })
1607        .collect();
1608    response.send(proto::UsersResponse { users })?;
1609    Ok(())
1610}
1611
1612async fn fuzzy_search_users(
1613    request: proto::FuzzySearchUsers,
1614    response: Response<proto::FuzzySearchUsers>,
1615    session: Session,
1616) -> Result<()> {
1617    let query = request.query;
1618    let users = match query.len() {
1619        0 => vec![],
1620        1 | 2 => session
1621            .db()
1622            .await
1623            .get_user_by_github_account(&query, None)
1624            .await?
1625            .into_iter()
1626            .collect(),
1627        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1628    };
1629    let users = users
1630        .into_iter()
1631        .filter(|user| user.id != session.user_id)
1632        .map(|user| proto::User {
1633            id: user.id.to_proto(),
1634            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1635            github_login: user.github_login,
1636        })
1637        .collect();
1638    response.send(proto::UsersResponse { users })?;
1639    Ok(())
1640}
1641
1642async fn request_contact(
1643    request: proto::RequestContact,
1644    response: Response<proto::RequestContact>,
1645    session: Session,
1646) -> Result<()> {
1647    let requester_id = session.user_id;
1648    let responder_id = UserId::from_proto(request.responder_id);
1649    if requester_id == responder_id {
1650        return Err(anyhow!("cannot add yourself as a contact"))?;
1651    }
1652
1653    session
1654        .db()
1655        .await
1656        .send_contact_request(requester_id, responder_id)
1657        .await?;
1658
1659    // Update outgoing contact requests of requester
1660    let mut update = proto::UpdateContacts::default();
1661    update.outgoing_requests.push(responder_id.to_proto());
1662    for connection_id in session
1663        .connection_pool()
1664        .await
1665        .user_connection_ids(requester_id)
1666    {
1667        session.peer.send(connection_id, update.clone())?;
1668    }
1669
1670    // Update incoming contact requests of responder
1671    let mut update = proto::UpdateContacts::default();
1672    update
1673        .incoming_requests
1674        .push(proto::IncomingContactRequest {
1675            requester_id: requester_id.to_proto(),
1676            should_notify: true,
1677        });
1678    for connection_id in session
1679        .connection_pool()
1680        .await
1681        .user_connection_ids(responder_id)
1682    {
1683        session.peer.send(connection_id, update.clone())?;
1684    }
1685
1686    response.send(proto::Ack {})?;
1687    Ok(())
1688}
1689
1690async fn respond_to_contact_request(
1691    request: proto::RespondToContactRequest,
1692    response: Response<proto::RespondToContactRequest>,
1693    session: Session,
1694) -> Result<()> {
1695    let responder_id = session.user_id;
1696    let requester_id = UserId::from_proto(request.requester_id);
1697    let db = session.db().await;
1698    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1699        db.dismiss_contact_notification(responder_id, requester_id)
1700            .await?;
1701    } else {
1702        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1703
1704        db.respond_to_contact_request(responder_id, requester_id, accept)
1705            .await?;
1706        let requester_busy = db.is_user_busy(requester_id).await?;
1707        let responder_busy = db.is_user_busy(responder_id).await?;
1708
1709        let pool = session.connection_pool().await;
1710        // Update responder with new contact
1711        let mut update = proto::UpdateContacts::default();
1712        if accept {
1713            update
1714                .contacts
1715                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1716        }
1717        update
1718            .remove_incoming_requests
1719            .push(requester_id.to_proto());
1720        for connection_id in pool.user_connection_ids(responder_id) {
1721            session.peer.send(connection_id, update.clone())?;
1722        }
1723
1724        // Update requester with new contact
1725        let mut update = proto::UpdateContacts::default();
1726        if accept {
1727            update
1728                .contacts
1729                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1730        }
1731        update
1732            .remove_outgoing_requests
1733            .push(responder_id.to_proto());
1734        for connection_id in pool.user_connection_ids(requester_id) {
1735            session.peer.send(connection_id, update.clone())?;
1736        }
1737    }
1738
1739    response.send(proto::Ack {})?;
1740    Ok(())
1741}
1742
1743async fn remove_contact(
1744    request: proto::RemoveContact,
1745    response: Response<proto::RemoveContact>,
1746    session: Session,
1747) -> Result<()> {
1748    let requester_id = session.user_id;
1749    let responder_id = UserId::from_proto(request.user_id);
1750    let db = session.db().await;
1751    db.remove_contact(requester_id, responder_id).await?;
1752
1753    let pool = session.connection_pool().await;
1754    // Update outgoing contact requests of requester
1755    let mut update = proto::UpdateContacts::default();
1756    update
1757        .remove_outgoing_requests
1758        .push(responder_id.to_proto());
1759    for connection_id in pool.user_connection_ids(requester_id) {
1760        session.peer.send(connection_id, update.clone())?;
1761    }
1762
1763    // Update incoming contact requests of responder
1764    let mut update = proto::UpdateContacts::default();
1765    update
1766        .remove_incoming_requests
1767        .push(requester_id.to_proto());
1768    for connection_id in pool.user_connection_ids(responder_id) {
1769        session.peer.send(connection_id, update.clone())?;
1770    }
1771
1772    response.send(proto::Ack {})?;
1773    Ok(())
1774}
1775
1776async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1777    let project_id = ProjectId::from_proto(request.project_id);
1778    let project_connection_ids = session
1779        .db()
1780        .await
1781        .project_connection_ids(project_id, session.connection_id)
1782        .await?;
1783    broadcast(
1784        session.connection_id,
1785        project_connection_ids.iter().copied(),
1786        |connection_id| {
1787            session
1788                .peer
1789                .forward_send(session.connection_id, connection_id, request.clone())
1790        },
1791    );
1792    Ok(())
1793}
1794
1795async fn get_private_user_info(
1796    _request: proto::GetPrivateUserInfo,
1797    response: Response<proto::GetPrivateUserInfo>,
1798    session: Session,
1799) -> Result<()> {
1800    let metrics_id = session
1801        .db()
1802        .await
1803        .get_user_metrics_id(session.user_id)
1804        .await?;
1805    let user = session
1806        .db()
1807        .await
1808        .get_user_by_id(session.user_id)
1809        .await?
1810        .ok_or_else(|| anyhow!("user not found"))?;
1811    response.send(proto::GetPrivateUserInfoResponse {
1812        metrics_id,
1813        staff: user.admin,
1814    })?;
1815    Ok(())
1816}
1817
1818fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1819    match message {
1820        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1821        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1822        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1823        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1824        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1825            code: frame.code.into(),
1826            reason: frame.reason,
1827        })),
1828    }
1829}
1830
1831fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1832    match message {
1833        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1834        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1835        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1836        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1837        AxumMessage::Close(frame) => {
1838            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1839                code: frame.code.into(),
1840                reason: frame.reason,
1841            }))
1842        }
1843    }
1844}
1845
1846fn build_initial_contacts_update(
1847    contacts: Vec<db::Contact>,
1848    pool: &ConnectionPool,
1849) -> proto::UpdateContacts {
1850    let mut update = proto::UpdateContacts::default();
1851
1852    for contact in contacts {
1853        match contact {
1854            db::Contact::Accepted {
1855                user_id,
1856                should_notify,
1857                busy,
1858            } => {
1859                update
1860                    .contacts
1861                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1862            }
1863            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1864            db::Contact::Incoming {
1865                user_id,
1866                should_notify,
1867            } => update
1868                .incoming_requests
1869                .push(proto::IncomingContactRequest {
1870                    requester_id: user_id.to_proto(),
1871                    should_notify,
1872                }),
1873        }
1874    }
1875
1876    update
1877}
1878
1879fn contact_for_user(
1880    user_id: UserId,
1881    should_notify: bool,
1882    busy: bool,
1883    pool: &ConnectionPool,
1884) -> proto::Contact {
1885    proto::Contact {
1886        user_id: user_id.to_proto(),
1887        online: pool.is_user_online(user_id),
1888        busy,
1889        should_notify,
1890    }
1891}
1892
1893fn room_updated(room: &proto::Room, peer: &Peer) {
1894    for participant in &room.participants {
1895        peer.send(
1896            ConnectionId(participant.peer_id),
1897            proto::RoomUpdated {
1898                room: Some(room.clone()),
1899            },
1900        )
1901        .trace_err();
1902    }
1903}
1904
1905async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1906    let db = session.db().await;
1907    let contacts = db.get_contacts(user_id).await?;
1908    let busy = db.is_user_busy(user_id).await?;
1909
1910    let pool = session.connection_pool().await;
1911    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1912    for contact in contacts {
1913        if let db::Contact::Accepted {
1914            user_id: contact_user_id,
1915            ..
1916        } = contact
1917        {
1918            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1919                session
1920                    .peer
1921                    .send(
1922                        contact_conn_id,
1923                        proto::UpdateContacts {
1924                            contacts: vec![updated_contact.clone()],
1925                            remove_contacts: Default::default(),
1926                            incoming_requests: Default::default(),
1927                            remove_incoming_requests: Default::default(),
1928                            outgoing_requests: Default::default(),
1929                            remove_outgoing_requests: Default::default(),
1930                        },
1931                    )
1932                    .trace_err();
1933            }
1934        }
1935    }
1936    Ok(())
1937}
1938
1939async fn leave_room_for_session(session: &Session) -> Result<()> {
1940    let mut contacts_to_update = HashSet::default();
1941
1942    let room_id;
1943    let canceled_calls_to_user_ids;
1944    let live_kit_room;
1945    let delete_live_kit_room;
1946    {
1947        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1948        contacts_to_update.insert(session.user_id);
1949
1950        for project in left_room.left_projects.values() {
1951            project_left(project, session);
1952        }
1953
1954        room_updated(&left_room.room, &session.peer);
1955        room_id = RoomId::from_proto(left_room.room.id);
1956        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1957        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1958        delete_live_kit_room = left_room.room.participants.is_empty();
1959    }
1960
1961    {
1962        let pool = session.connection_pool().await;
1963        for canceled_user_id in canceled_calls_to_user_ids {
1964            for connection_id in pool.user_connection_ids(canceled_user_id) {
1965                session
1966                    .peer
1967                    .send(
1968                        connection_id,
1969                        proto::CallCanceled {
1970                            room_id: room_id.to_proto(),
1971                        },
1972                    )
1973                    .trace_err();
1974            }
1975            contacts_to_update.insert(canceled_user_id);
1976        }
1977    }
1978
1979    for contact_user_id in contacts_to_update {
1980        update_user_contacts(contact_user_id, &session).await?;
1981    }
1982
1983    if let Some(live_kit) = session.live_kit_client.as_ref() {
1984        live_kit
1985            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1986            .await
1987            .trace_err();
1988
1989        if delete_live_kit_room {
1990            live_kit.delete_room(live_kit_room).await.trace_err();
1991        }
1992    }
1993
1994    Ok(())
1995}
1996
1997fn project_left(project: &db::LeftProject, session: &Session) {
1998    for connection_id in &project.connection_ids {
1999        if project.host_user_id == session.user_id {
2000            session
2001                .peer
2002                .send(
2003                    *connection_id,
2004                    proto::UnshareProject {
2005                        project_id: project.id.to_proto(),
2006                    },
2007                )
2008                .trace_err();
2009        } else {
2010            session
2011                .peer
2012                .send(
2013                    *connection_id,
2014                    proto::RemoveProjectCollaborator {
2015                        project_id: project.id.to_proto(),
2016                        peer_id: session.connection_id.0,
2017                    },
2018                )
2019                .trace_err();
2020        }
2021    }
2022
2023    session
2024        .peer
2025        .send(
2026            session.connection_id,
2027            proto::UnshareProject {
2028                project_id: project.id.to_proto(),
2029            },
2030        )
2031        .trace_err();
2032}
2033
2034pub trait ResultExt {
2035    type Ok;
2036
2037    fn trace_err(self) -> Option<Self::Ok>;
2038}
2039
2040impl<T, E> ResultExt for Result<T, E>
2041where
2042    E: std::fmt::Debug,
2043{
2044    type Ok = T;
2045
2046    fn trace_err(self) -> Option<T> {
2047        match self {
2048            Ok(value) => Some(value),
2049            Err(error) => {
2050                tracing::error!("{:?}", error);
2051                None
2052            }
2053        }
2054    }
2055}