rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::{watch, Mutex, MutexGuard};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
/// Grace period applied after a connection is lost, before its room state is cleaned up.
///
/// Equal to `rpc::RECEIVE_TIMEOUT`. `Server::start` sleeps this long before refreshing
/// outdated rooms, and `sign_out` sleeps this long before calling `leave_room_for_session`.
pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;

lazy_static! {
    /// Prometheus gauge `connections`; `handle_metrics` sets it to the number of non-admin
    /// connections in the connection pool. Panics on first access if registration fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Prometheus gauge `shared_projects`; `handle_metrics` sets it from
    /// `project_count_excluding_admins`. Panics on first access if registration fails.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  71
  72type MessageHandler =
  73    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
  75struct Response<R> {
  76    peer: Arc<Peer>,
  77    receipt: Receipt<R>,
  78    responded: Arc<AtomicBool>,
  79}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
  89#[derive(Clone)]
  90struct Session {
  91    user_id: UserId,
  92    connection_id: ConnectionId,
  93    db: Arc<Mutex<DbHandle>>,
  94    peer: Arc<Peer>,
  95    connection_pool: Arc<Mutex<ConnectionPool>>,
  96    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  97}
  98
  99impl Session {
 100    async fn db(&self) -> MutexGuard<DbHandle> {
 101        #[cfg(test)]
 102        tokio::task::yield_now().await;
 103        let guard = self.db.lock().await;
 104        #[cfg(test)]
 105        tokio::task::yield_now().await;
 106        guard
 107    }
 108
 109    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 110        #[cfg(test)]
 111        tokio::task::yield_now().await;
 112        let guard = self.connection_pool.lock().await;
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        ConnectionPoolGuard {
 116            guard,
 117            _not_send: PhantomData,
 118        }
 119    }
 120}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
 131struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
/// The collaboration RPC server.
///
/// It owns the peer, the shared connection pool, and the table of handlers used to
/// dispatch incoming messages.
pub struct Server {
    peer: Arc<Peer>,
    /// Live connections; a clone of this `Arc` is placed in every `Session`.
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they were registered for.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Fired by `Server::teardown`. Connection loops return and `sign_out` stops waiting
    /// when it changes.
    teardown: watch::Sender<()>,
}
 149
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. It therefore cannot be held
/// across an `.await` inside a future that has to be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 154
/// Serializable view of the server state, produced by `Server::snapshot`.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's deref target (the `ConnectionPool`), not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 161
 162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 163where
 164    S: Serializer,
 165    T: Deref<Target = U>,
 166    U: Serialize,
 167{
 168    Serialize::serialize(value.deref(), serializer)
 169}
 170
 171impl Server {
 172    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 173        let mut server = Self {
 174            peer: Peer::new(),
 175            app_state,
 176            executor,
 177            connection_pool: Default::default(),
 178            handlers: Default::default(),
 179            teardown: watch::channel(()).0,
 180        };
 181
 182        server
 183            .add_request_handler(ping)
 184            .add_request_handler(create_room)
 185            .add_request_handler(join_room)
 186            .add_message_handler(leave_room)
 187            .add_request_handler(call)
 188            .add_request_handler(cancel_call)
 189            .add_message_handler(decline_call)
 190            .add_request_handler(update_participant_location)
 191            .add_request_handler(share_project)
 192            .add_message_handler(unshare_project)
 193            .add_request_handler(join_project)
 194            .add_message_handler(leave_project)
 195            .add_request_handler(update_project)
 196            .add_request_handler(update_worktree)
 197            .add_message_handler(start_language_server)
 198            .add_message_handler(update_language_server)
 199            .add_message_handler(update_diagnostic_summary)
 200            .add_request_handler(forward_project_request::<proto::GetHover>)
 201            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 202            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 203            .add_request_handler(forward_project_request::<proto::GetReferences>)
 204            .add_request_handler(forward_project_request::<proto::SearchProject>)
 205            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 206            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 209            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 210            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 211            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 212            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 213            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 214            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 215            .add_request_handler(forward_project_request::<proto::PerformRename>)
 216            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 217            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 218            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 221            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 222            .add_message_handler(create_buffer_for_peer)
 223            .add_request_handler(update_buffer)
 224            .add_message_handler(update_buffer_file)
 225            .add_message_handler(buffer_reloaded)
 226            .add_message_handler(buffer_saved)
 227            .add_request_handler(save_buffer)
 228            .add_request_handler(get_users)
 229            .add_request_handler(fuzzy_search_users)
 230            .add_request_handler(request_contact)
 231            .add_request_handler(remove_contact)
 232            .add_request_handler(respond_to_contact_request)
 233            .add_request_handler(follow)
 234            .add_message_handler(unfollow)
 235            .add_message_handler(update_followers)
 236            .add_message_handler(update_diff_base)
 237            .add_request_handler(get_private_user_info);
 238
 239        Arc::new(server)
 240    }
 241
 242    pub async fn start(&self) -> Result<()> {
 243        self.app_state.db.delete_stale_projects().await?;
 244        let db = self.app_state.db.clone();
 245        let peer = self.peer.clone();
 246        let timeout = self.executor.sleep(RECONNECT_TIMEOUT);
 247        let pool = self.connection_pool.clone();
 248        let live_kit_client = self.app_state.live_kit_client.clone();
 249        self.executor.spawn_detached(async move {
 250            timeout.await;
 251            if let Some(room_ids) = db.outdated_room_ids().await.trace_err() {
 252                for room_id in room_ids {
 253                    let mut contacts_to_update = HashSet::default();
 254                    let mut canceled_calls_to_user_ids = Vec::new();
 255                    let mut live_kit_room = String::new();
 256                    let mut delete_live_kit_room = false;
 257
 258                    if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 259                        room_updated(&refreshed_room.room, &peer);
 260                        contacts_to_update
 261                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 262                        contacts_to_update
 263                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 264                        canceled_calls_to_user_ids =
 265                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 266                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 267                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
 268                    }
 269
 270                    {
 271                        let pool = pool.lock().await;
 272                        for canceled_user_id in canceled_calls_to_user_ids {
 273                            for connection_id in pool.user_connection_ids(canceled_user_id) {
 274                                peer.send(connection_id, proto::CallCanceled {}).trace_err();
 275                            }
 276                        }
 277                    }
 278
 279                    for user_id in contacts_to_update {
 280                        if let Some((busy, contacts)) = db
 281                            .is_user_busy(user_id)
 282                            .await
 283                            .trace_err()
 284                            .zip(db.get_contacts(user_id).await.trace_err())
 285                        {
 286                            let pool = pool.lock().await;
 287                            let updated_contact = contact_for_user(user_id, false, busy, &pool);
 288                            for contact in contacts {
 289                                if let db::Contact::Accepted {
 290                                    user_id: contact_user_id,
 291                                    ..
 292                                } = contact
 293                                {
 294                                    for contact_conn_id in pool.user_connection_ids(contact_user_id)
 295                                    {
 296                                        peer.send(
 297                                            contact_conn_id,
 298                                            proto::UpdateContacts {
 299                                                contacts: vec![updated_contact.clone()],
 300                                                remove_contacts: Default::default(),
 301                                                incoming_requests: Default::default(),
 302                                                remove_incoming_requests: Default::default(),
 303                                                outgoing_requests: Default::default(),
 304                                                remove_outgoing_requests: Default::default(),
 305                                            },
 306                                        )
 307                                        .trace_err();
 308                                    }
 309                                }
 310                            }
 311                        }
 312                    }
 313
 314                    if let Some(live_kit) = live_kit_client.as_ref() {
 315                        if delete_live_kit_room {
 316                            live_kit.delete_room(live_kit_room).await.trace_err();
 317                        }
 318                    }
 319                }
 320            }
 321        });
 322        Ok(())
 323    }
 324
 325    pub fn teardown(&self) {
 326        self.peer.reset();
 327        let _ = self.teardown.send(());
 328    }
 329
 330    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 331    where
 332        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 333        Fut: 'static + Send + Future<Output = Result<()>>,
 334        M: EnvelopedMessage,
 335    {
 336        let prev_handler = self.handlers.insert(
 337            TypeId::of::<M>(),
 338            Box::new(move |envelope, session| {
 339                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 340                let span = info_span!(
 341                    "handle message",
 342                    payload_type = envelope.payload_type_name()
 343                );
 344                span.in_scope(|| {
 345                    tracing::info!(
 346                        payload_type = envelope.payload_type_name(),
 347                        "message received"
 348                    );
 349                });
 350                let future = (handler)(*envelope, session);
 351                async move {
 352                    if let Err(error) = future.await {
 353                        tracing::error!(%error, "error handling message");
 354                    }
 355                }
 356                .instrument(span)
 357                .boxed()
 358            }),
 359        );
 360        if prev_handler.is_some() {
 361            panic!("registered a handler for the same message twice");
 362        }
 363        self
 364    }
 365
 366    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 367    where
 368        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 369        Fut: 'static + Send + Future<Output = Result<()>>,
 370        M: EnvelopedMessage,
 371    {
 372        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 373        self
 374    }
 375
 376    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 377    where
 378        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 379        Fut: Send + Future<Output = Result<()>>,
 380        M: RequestMessage,
 381    {
 382        let handler = Arc::new(handler);
 383        self.add_handler(move |envelope, session| {
 384            let receipt = envelope.receipt();
 385            let handler = handler.clone();
 386            async move {
 387                let peer = session.peer.clone();
 388                let responded = Arc::new(AtomicBool::default());
 389                let response = Response {
 390                    peer: peer.clone(),
 391                    responded: responded.clone(),
 392                    receipt,
 393                };
 394                match (handler)(envelope.payload, response, session).await {
 395                    Ok(()) => {
 396                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 397                            Ok(())
 398                        } else {
 399                            Err(anyhow!("handler did not send a response"))?
 400                        }
 401                    }
 402                    Err(error) => {
 403                        peer.respond_with_error(
 404                            receipt,
 405                            proto::Error {
 406                                message: error.to_string(),
 407                            },
 408                        )?;
 409                        Err(error)
 410                    }
 411                }
 412            }
 413        })
 414    }
 415
 416    pub fn handle_connection(
 417        self: &Arc<Self>,
 418        connection: Connection,
 419        address: String,
 420        user: User,
 421        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 422        executor: Executor,
 423    ) -> impl Future<Output = Result<()>> {
 424        let this = self.clone();
 425        let user_id = user.id;
 426        let login = user.github_login;
 427        let span = info_span!("handle connection", %user_id, %login, %address);
 428        let mut teardown = self.teardown.subscribe();
 429        async move {
 430            let (connection_id, handle_io, mut incoming_rx) = this
 431                .peer
 432                .add_connection(connection, {
 433                    let executor = executor.clone();
 434                    move |duration| executor.sleep(duration)
 435                });
 436
 437            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 438            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 439            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 440
 441            if let Some(send_connection_id) = send_connection_id.take() {
 442                let _ = send_connection_id.send(connection_id);
 443            }
 444
 445            if !user.connected_once {
 446                this.peer.send(connection_id, proto::ShowContacts {})?;
 447                this.app_state.db.set_user_connected_once(user_id, true).await?;
 448            }
 449
 450            let (contacts, invite_code) = future::try_join(
 451                this.app_state.db.get_contacts(user_id),
 452                this.app_state.db.get_invite_code_for_user(user_id)
 453            ).await?;
 454
 455            {
 456                let mut pool = this.connection_pool.lock().await;
 457                pool.add_connection(connection_id, user_id, user.admin);
 458                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 459
 460                if let Some((code, count)) = invite_code {
 461                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 462                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 463                        count: count as u32,
 464                    })?;
 465                }
 466            }
 467
 468            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 469                this.peer.send(connection_id, incoming_call)?;
 470            }
 471
 472            let session = Session {
 473                user_id,
 474                connection_id,
 475                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
 476                peer: this.peer.clone(),
 477                connection_pool: this.connection_pool.clone(),
 478                live_kit_client: this.app_state.live_kit_client.clone()
 479            };
 480            update_user_contacts(user_id, &session).await?;
 481
 482            let handle_io = handle_io.fuse();
 483            futures::pin_mut!(handle_io);
 484
 485            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 486            // This prevents deadlocks when e.g., client A performs a request to client B and
 487            // client B performs a request to client A. If both clients stop processing further
 488            // messages until their respective request completes, they won't have a chance to
 489            // respond to the other client's request and cause a deadlock.
 490            //
 491            // This arrangement ensures we will attempt to process earlier messages first, but fall
 492            // back to processing messages arrived later in the spirit of making progress.
 493            let mut foreground_message_handlers = FuturesUnordered::new();
 494            loop {
 495                let next_message = incoming_rx.next().fuse();
 496                futures::pin_mut!(next_message);
 497                futures::select_biased! {
 498                    _ = teardown.changed().fuse() => return Ok(()),
 499                    result = handle_io => {
 500                        if let Err(error) = result {
 501                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 502                        }
 503                        break;
 504                    }
 505                    _ = foreground_message_handlers.next() => {}
 506                    message = next_message => {
 507                        if let Some(message) = message {
 508                            let type_name = message.payload_type_name();
 509                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 510                            let span_enter = span.enter();
 511                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 512                                let is_background = message.is_background();
 513                                let handle_message = (handler)(message, session.clone());
 514                                drop(span_enter);
 515
 516                                let handle_message = handle_message.instrument(span);
 517                                if is_background {
 518                                    executor.spawn_detached(handle_message);
 519                                } else {
 520                                    foreground_message_handlers.push(handle_message);
 521                                }
 522                            } else {
 523                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 524                            }
 525                        } else {
 526                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 527                            break;
 528                        }
 529                    }
 530                }
 531            }
 532
 533            drop(foreground_message_handlers);
 534            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 535            if let Err(error) = sign_out(session, teardown, executor).await {
 536                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 537            }
 538
 539            Ok(())
 540        }.instrument(span)
 541    }
 542
 543    pub async fn invite_code_redeemed(
 544        self: &Arc<Self>,
 545        inviter_id: UserId,
 546        invitee_id: UserId,
 547    ) -> Result<()> {
 548        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 549            if let Some(code) = &user.invite_code {
 550                let pool = self.connection_pool.lock().await;
 551                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 552                for connection_id in pool.user_connection_ids(inviter_id) {
 553                    self.peer.send(
 554                        connection_id,
 555                        proto::UpdateContacts {
 556                            contacts: vec![invitee_contact.clone()],
 557                            ..Default::default()
 558                        },
 559                    )?;
 560                    self.peer.send(
 561                        connection_id,
 562                        proto::UpdateInviteInfo {
 563                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 564                            count: user.invite_count as u32,
 565                        },
 566                    )?;
 567                }
 568            }
 569        }
 570        Ok(())
 571    }
 572
 573    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 574        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 575            if let Some(invite_code) = &user.invite_code {
 576                let pool = self.connection_pool.lock().await;
 577                for connection_id in pool.user_connection_ids(user_id) {
 578                    self.peer.send(
 579                        connection_id,
 580                        proto::UpdateInviteInfo {
 581                            url: format!(
 582                                "{}{}",
 583                                self.app_state.config.invite_link_prefix, invite_code
 584                            ),
 585                            count: user.invite_count as u32,
 586                        },
 587                    )?;
 588                }
 589            }
 590        }
 591        Ok(())
 592    }
 593
 594    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 595        ServerSnapshot {
 596            connection_pool: ConnectionPoolGuard {
 597                guard: self.connection_pool.lock().await,
 598                _not_send: PhantomData,
 599            },
 600            peer: &self.peer,
 601        }
 602    }
 603}
 604
 605impl<'a> Deref for ConnectionPoolGuard<'a> {
 606    type Target = ConnectionPool;
 607
 608    fn deref(&self) -> &Self::Target {
 609        &*self.guard
 610    }
 611}
 612
 613impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 614    fn deref_mut(&mut self) -> &mut Self::Target {
 615        &mut *self.guard
 616    }
 617}
 618
/// In test builds, checks the connection pool's invariants whenever a guard is released.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Runs before the inner `MutexGuard` field is dropped, i.e. while the lock is still
        // held. The statement is compiled out entirely outside `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
 625
 626fn broadcast<F>(
 627    sender_id: ConnectionId,
 628    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 629    mut f: F,
 630) where
 631    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 632{
 633    for receiver_id in receiver_ids {
 634        if receiver_id != sender_id {
 635            f(receiver_id).trace_err();
 636        }
 637    }
 638}
 639
 640lazy_static! {
 641    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 642}
 643
 644pub struct ProtocolVersion(u32);
 645
 646impl Header for ProtocolVersion {
 647    fn name() -> &'static HeaderName {
 648        &ZED_PROTOCOL_VERSION
 649    }
 650
 651    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 652    where
 653        Self: Sized,
 654        I: Iterator<Item = &'i axum::http::HeaderValue>,
 655    {
 656        let version = values
 657            .next()
 658            .ok_or_else(axum::headers::Error::invalid)?
 659            .to_str()
 660            .map_err(|_| axum::headers::Error::invalid())?
 661            .parse()
 662            .map_err(|_| axum::headers::Error::invalid())?;
 663        Ok(Self(version))
 664    }
 665
 666    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 667        values.extend([self.0.to_string().parse().unwrap()]);
 668    }
 669}
 670
 671pub fn routes(server: Arc<Server>) -> Router<Body> {
 672    Router::new()
 673        .route("/rpc", get(handle_websocket_request))
 674        .layer(
 675            ServiceBuilder::new()
 676                .layer(Extension(server.app_state.clone()))
 677                .layer(middleware::from_fn(auth::validate_header)),
 678        )
 679        .route("/metrics", get(handle_metrics))
 680        .layer(Extension(server))
 681}
 682
 683pub async fn handle_websocket_request(
 684    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 685    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 686    Extension(server): Extension<Arc<Server>>,
 687    Extension(user): Extension<User>,
 688    ws: WebSocketUpgrade,
 689) -> axum::response::Response {
 690    if protocol_version != rpc::PROTOCOL_VERSION {
 691        return (
 692            StatusCode::UPGRADE_REQUIRED,
 693            "client must be upgraded".to_string(),
 694        )
 695            .into_response();
 696    }
 697    let socket_address = socket_address.to_string();
 698    ws.on_upgrade(move |socket| {
 699        use util::ResultExt;
 700        let socket = socket
 701            .map_ok(to_tungstenite_message)
 702            .err_into()
 703            .with(|message| async move { Ok(to_axum_message(message)) });
 704        let connection = Connection::new(Box::pin(socket));
 705        async move {
 706            server
 707                .handle_connection(connection, socket_address, user, None, Executor::Production)
 708                .await
 709                .log_err();
 710        }
 711    })
 712}
 713
 714pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 715    let connections = server
 716        .connection_pool
 717        .lock()
 718        .await
 719        .connections()
 720        .filter(|connection| !connection.admin)
 721        .count();
 722
 723    METRIC_CONNECTIONS.set(connections as _);
 724
 725    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 726    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 727
 728    let encoder = prometheus::TextEncoder::new();
 729    let metric_families = prometheus::gather();
 730    let encoded_metrics = encoder
 731        .encode_to_string(&metric_families)
 732        .map_err(|err| anyhow!("{}", err))?;
 733    Ok(encoded_metrics)
 734}
 735
/// Cleans up after a connection has closed.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - call `project_left` for every project reported by `connection_lost`.
///
/// It then waits `RECONNECT_TIMEOUT`. Unless teardown fires first, it then:
/// - calls `leave_room_for_session`;
/// - if the user has no other online connection, calls `decline_call` and broadcasts the
///   resulting room via `room_updated`;
/// - runs `update_user_contacts`.
///
/// # Errors
/// Returns an error if `remove_connection` or `update_user_contacts` fails. The other
/// steps only log their failures via `trace_err`.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // `connection_lost` hands back the projects this connection had joined; notify each.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Biased toward the timeout branch; a teardown signal ends the wait and skips the
    // delayed room/call/contact cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline pending calls if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 781
 782async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 783    response.send(proto::Ack {})?;
 784    Ok(())
 785}
 786
 787async fn create_room(
 788    _request: proto::CreateRoom,
 789    response: Response<proto::CreateRoom>,
 790    session: Session,
 791) -> Result<()> {
 792    let live_kit_room = nanoid::nanoid!(30);
 793    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 794        if let Some(_) = live_kit
 795            .create_room(live_kit_room.clone())
 796            .await
 797            .trace_err()
 798        {
 799            if let Some(token) = live_kit
 800                .room_token(&live_kit_room, &session.connection_id.to_string())
 801                .trace_err()
 802            {
 803                Some(proto::LiveKitConnectionInfo {
 804                    server_url: live_kit.url().into(),
 805                    token,
 806                })
 807            } else {
 808                None
 809            }
 810        } else {
 811            None
 812        }
 813    } else {
 814        None
 815    };
 816
 817    {
 818        let room = session
 819            .db()
 820            .await
 821            .create_room(session.user_id, session.connection_id, &live_kit_room)
 822            .await?;
 823
 824        response.send(proto::CreateRoomResponse {
 825            room: Some(room.clone()),
 826            live_kit_connection_info,
 827        })?;
 828    }
 829
 830    update_user_contacts(session.user_id, &session).await?;
 831    Ok(())
 832}
 833
 834async fn join_room(
 835    request: proto::JoinRoom,
 836    response: Response<proto::JoinRoom>,
 837    session: Session,
 838) -> Result<()> {
 839    let room = {
 840        let room = session
 841            .db()
 842            .await
 843            .join_room(
 844                RoomId::from_proto(request.id),
 845                session.user_id,
 846                session.connection_id,
 847            )
 848            .await?;
 849        room_updated(&room, &session.peer);
 850        room.clone()
 851    };
 852
 853    for connection_id in session
 854        .connection_pool()
 855        .await
 856        .user_connection_ids(session.user_id)
 857    {
 858        session
 859            .peer
 860            .send(connection_id, proto::CallCanceled {})
 861            .trace_err();
 862    }
 863
 864    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 865        if let Some(token) = live_kit
 866            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 867            .trace_err()
 868        {
 869            Some(proto::LiveKitConnectionInfo {
 870                server_url: live_kit.url().into(),
 871                token,
 872            })
 873        } else {
 874            None
 875        }
 876    } else {
 877        None
 878    };
 879
 880    response.send(proto::JoinRoomResponse {
 881        room: Some(room),
 882        live_kit_connection_info,
 883    })?;
 884
 885    update_user_contacts(session.user_id, &session).await?;
 886    Ok(())
 887}
 888
 889async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 890    leave_room_for_session(&session).await
 891}
 892
/// Handles a `Call` request: the session's user rings `called_user_id` into
/// room `room_id`, optionally with an initial project.
///
/// Fails with "cannot call a user who isn't a contact" unless `has_contact`
/// reports the two users as contacts. Otherwise it records the call in the
/// database, calls `room_updated`, and refreshes the callee's contacts. It then
/// sends the incoming-call message returned by the database, as a request, to
/// every connection of the callee concurrently.
///
/// - The first connection that answers successfully causes an `Ack` to be sent
///   and the handler to return. The futures still pending in `calls` are
///   dropped when the function returns.
/// - If every request fails, or the callee has no connections, the call is
///   marked failed via `call_failed`, `room_updated` runs again, the callee's
///   contacts are refreshed, and the handler fails with "failed to ring user".
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` (bound through `&mut *`) is
    // dropped before the contacts update; the incoming-call message is moved
    // out of it with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The connection-pool guard is a temporary of this statement, so it is
    // released before any of the requests below are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Results arrive in completion order; the first success wins.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Report this connection's failure and keep waiting on the rest.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: roll it back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
 960
 961async fn cancel_call(
 962    request: proto::CancelCall,
 963    response: Response<proto::CancelCall>,
 964    session: Session,
 965) -> Result<()> {
 966    let called_user_id = UserId::from_proto(request.called_user_id);
 967    let room_id = RoomId::from_proto(request.room_id);
 968    {
 969        let room = session
 970            .db()
 971            .await
 972            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 973            .await?;
 974        room_updated(&room, &session.peer);
 975    }
 976
 977    for connection_id in session
 978        .connection_pool()
 979        .await
 980        .user_connection_ids(called_user_id)
 981    {
 982        session
 983            .peer
 984            .send(connection_id, proto::CallCanceled {})
 985            .trace_err();
 986    }
 987    response.send(proto::Ack {})?;
 988
 989    update_user_contacts(called_user_id, &session).await?;
 990    Ok(())
 991}
 992
 993async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
 994    let room_id = RoomId::from_proto(message.room_id);
 995    {
 996        let room = session
 997            .db()
 998            .await
 999            .decline_call(Some(room_id), session.user_id)
1000            .await?;
1001        room_updated(&room, &session.peer);
1002    }
1003
1004    for connection_id in session
1005        .connection_pool()
1006        .await
1007        .user_connection_ids(session.user_id)
1008    {
1009        session
1010            .peer
1011            .send(connection_id, proto::CallCanceled {})
1012            .trace_err();
1013    }
1014    update_user_contacts(session.user_id, &session).await?;
1015    Ok(())
1016}
1017
1018async fn update_participant_location(
1019    request: proto::UpdateParticipantLocation,
1020    response: Response<proto::UpdateParticipantLocation>,
1021    session: Session,
1022) -> Result<()> {
1023    let room_id = RoomId::from_proto(request.room_id);
1024    let location = request
1025        .location
1026        .ok_or_else(|| anyhow!("invalid location"))?;
1027    let room = session
1028        .db()
1029        .await
1030        .update_room_participant_location(room_id, session.connection_id, location)
1031        .await?;
1032    room_updated(&room, &session.peer);
1033    response.send(proto::Ack {})?;
1034    Ok(())
1035}
1036
1037async fn share_project(
1038    request: proto::ShareProject,
1039    response: Response<proto::ShareProject>,
1040    session: Session,
1041) -> Result<()> {
1042    let (project_id, room) = &*session
1043        .db()
1044        .await
1045        .share_project(
1046            RoomId::from_proto(request.room_id),
1047            session.connection_id,
1048            &request.worktrees,
1049        )
1050        .await?;
1051    response.send(proto::ShareProjectResponse {
1052        project_id: project_id.to_proto(),
1053    })?;
1054    room_updated(&room, &session.peer);
1055
1056    Ok(())
1057}
1058
1059async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1060    let project_id = ProjectId::from_proto(message.project_id);
1061
1062    let (room, guest_connection_ids) = &*session
1063        .db()
1064        .await
1065        .unshare_project(project_id, session.connection_id)
1066        .await?;
1067
1068    broadcast(
1069        session.connection_id,
1070        guest_connection_ids.iter().copied(),
1071        |conn_id| session.peer.send(conn_id, message.clone()),
1072    );
1073    room_updated(&room, &session.peer);
1074
1075    Ok(())
1076}
1077
1078async fn join_project(
1079    request: proto::JoinProject,
1080    response: Response<proto::JoinProject>,
1081    session: Session,
1082) -> Result<()> {
1083    let project_id = ProjectId::from_proto(request.project_id);
1084    let guest_user_id = session.user_id;
1085
1086    tracing::info!(%project_id, "join project");
1087
1088    let (project, replica_id) = &mut *session
1089        .db()
1090        .await
1091        .join_project(project_id, session.connection_id)
1092        .await?;
1093
1094    let collaborators = project
1095        .collaborators
1096        .iter()
1097        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1098        .map(|collaborator| proto::Collaborator {
1099            peer_id: collaborator.connection_id as u32,
1100            replica_id: collaborator.replica_id.0 as u32,
1101            user_id: collaborator.user_id.to_proto(),
1102        })
1103        .collect::<Vec<_>>();
1104    let worktrees = project
1105        .worktrees
1106        .iter()
1107        .map(|(id, worktree)| proto::WorktreeMetadata {
1108            id: *id,
1109            root_name: worktree.root_name.clone(),
1110            visible: worktree.visible,
1111            abs_path: worktree.abs_path.clone(),
1112        })
1113        .collect::<Vec<_>>();
1114
1115    for collaborator in &collaborators {
1116        session
1117            .peer
1118            .send(
1119                ConnectionId(collaborator.peer_id),
1120                proto::AddProjectCollaborator {
1121                    project_id: project_id.to_proto(),
1122                    collaborator: Some(proto::Collaborator {
1123                        peer_id: session.connection_id.0,
1124                        replica_id: replica_id.0 as u32,
1125                        user_id: guest_user_id.to_proto(),
1126                    }),
1127                },
1128            )
1129            .trace_err();
1130    }
1131
1132    // First, we send the metadata associated with each worktree.
1133    response.send(proto::JoinProjectResponse {
1134        worktrees: worktrees.clone(),
1135        replica_id: replica_id.0 as u32,
1136        collaborators: collaborators.clone(),
1137        language_servers: project.language_servers.clone(),
1138    })?;
1139
1140    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1141        #[cfg(any(test, feature = "test-support"))]
1142        const MAX_CHUNK_SIZE: usize = 2;
1143        #[cfg(not(any(test, feature = "test-support")))]
1144        const MAX_CHUNK_SIZE: usize = 256;
1145
1146        // Stream this worktree's entries.
1147        let message = proto::UpdateWorktree {
1148            project_id: project_id.to_proto(),
1149            worktree_id,
1150            abs_path: worktree.abs_path.clone(),
1151            root_name: worktree.root_name,
1152            updated_entries: worktree.entries,
1153            removed_entries: Default::default(),
1154            scan_id: worktree.scan_id,
1155            is_last_update: worktree.is_complete,
1156        };
1157        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1158            session.peer.send(session.connection_id, update.clone())?;
1159        }
1160
1161        // Stream this worktree's diagnostics.
1162        for summary in worktree.diagnostic_summaries {
1163            session.peer.send(
1164                session.connection_id,
1165                proto::UpdateDiagnosticSummary {
1166                    project_id: project_id.to_proto(),
1167                    worktree_id: worktree.id,
1168                    summary: Some(summary),
1169                },
1170            )?;
1171        }
1172    }
1173
1174    for language_server in &project.language_servers {
1175        session.peer.send(
1176            session.connection_id,
1177            proto::UpdateLanguageServer {
1178                project_id: project_id.to_proto(),
1179                language_server_id: language_server.id,
1180                variant: Some(
1181                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1182                        proto::LspDiskBasedDiagnosticsUpdated {},
1183                    ),
1184                ),
1185            },
1186        )?;
1187    }
1188
1189    Ok(())
1190}
1191
1192async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1193    let sender_id = session.connection_id;
1194    let project_id = ProjectId::from_proto(request.project_id);
1195
1196    let project = session
1197        .db()
1198        .await
1199        .leave_project(project_id, sender_id)
1200        .await?;
1201    tracing::info!(
1202        %project_id,
1203        host_user_id = %project.host_user_id,
1204        host_connection_id = %project.host_connection_id,
1205        "leave project"
1206    );
1207    project_left(&project, &session);
1208
1209    Ok(())
1210}
1211
1212async fn update_project(
1213    request: proto::UpdateProject,
1214    response: Response<proto::UpdateProject>,
1215    session: Session,
1216) -> Result<()> {
1217    let project_id = ProjectId::from_proto(request.project_id);
1218    let (room, guest_connection_ids) = &*session
1219        .db()
1220        .await
1221        .update_project(project_id, session.connection_id, &request.worktrees)
1222        .await?;
1223    broadcast(
1224        session.connection_id,
1225        guest_connection_ids.iter().copied(),
1226        |connection_id| {
1227            session
1228                .peer
1229                .forward_send(session.connection_id, connection_id, request.clone())
1230        },
1231    );
1232    room_updated(&room, &session.peer);
1233    response.send(proto::Ack {})?;
1234
1235    Ok(())
1236}
1237
1238async fn update_worktree(
1239    request: proto::UpdateWorktree,
1240    response: Response<proto::UpdateWorktree>,
1241    session: Session,
1242) -> Result<()> {
1243    let guest_connection_ids = session
1244        .db()
1245        .await
1246        .update_worktree(&request, session.connection_id)
1247        .await?;
1248
1249    broadcast(
1250        session.connection_id,
1251        guest_connection_ids.iter().copied(),
1252        |connection_id| {
1253            session
1254                .peer
1255                .forward_send(session.connection_id, connection_id, request.clone())
1256        },
1257    );
1258    response.send(proto::Ack {})?;
1259    Ok(())
1260}
1261
1262async fn update_diagnostic_summary(
1263    message: proto::UpdateDiagnosticSummary,
1264    session: Session,
1265) -> Result<()> {
1266    let guest_connection_ids = session
1267        .db()
1268        .await
1269        .update_diagnostic_summary(&message, session.connection_id)
1270        .await?;
1271
1272    broadcast(
1273        session.connection_id,
1274        guest_connection_ids.iter().copied(),
1275        |connection_id| {
1276            session
1277                .peer
1278                .forward_send(session.connection_id, connection_id, message.clone())
1279        },
1280    );
1281
1282    Ok(())
1283}
1284
1285async fn start_language_server(
1286    request: proto::StartLanguageServer,
1287    session: Session,
1288) -> Result<()> {
1289    let guest_connection_ids = session
1290        .db()
1291        .await
1292        .start_language_server(&request, session.connection_id)
1293        .await?;
1294
1295    broadcast(
1296        session.connection_id,
1297        guest_connection_ids.iter().copied(),
1298        |connection_id| {
1299            session
1300                .peer
1301                .forward_send(session.connection_id, connection_id, request.clone())
1302        },
1303    );
1304    Ok(())
1305}
1306
1307async fn update_language_server(
1308    request: proto::UpdateLanguageServer,
1309    session: Session,
1310) -> Result<()> {
1311    let project_id = ProjectId::from_proto(request.project_id);
1312    let project_connection_ids = session
1313        .db()
1314        .await
1315        .project_connection_ids(project_id, session.connection_id)
1316        .await?;
1317    broadcast(
1318        session.connection_id,
1319        project_connection_ids.iter().copied(),
1320        |connection_id| {
1321            session
1322                .peer
1323                .forward_send(session.connection_id, connection_id, request.clone())
1324        },
1325    );
1326    Ok(())
1327}
1328
1329async fn forward_project_request<T>(
1330    request: T,
1331    response: Response<T>,
1332    session: Session,
1333) -> Result<()>
1334where
1335    T: EntityMessage + RequestMessage,
1336{
1337    let project_id = ProjectId::from_proto(request.remote_entity_id());
1338    let host_connection_id = {
1339        let collaborators = session
1340            .db()
1341            .await
1342            .project_collaborators(project_id, session.connection_id)
1343            .await?;
1344        ConnectionId(
1345            collaborators
1346                .iter()
1347                .find(|collaborator| collaborator.is_host)
1348                .ok_or_else(|| anyhow!("host not found"))?
1349                .connection_id as u32,
1350        )
1351    };
1352
1353    let payload = session
1354        .peer
1355        .forward_request(session.connection_id, host_connection_id, request)
1356        .await?;
1357
1358    response.send(payload)?;
1359    Ok(())
1360}
1361
1362async fn save_buffer(
1363    request: proto::SaveBuffer,
1364    response: Response<proto::SaveBuffer>,
1365    session: Session,
1366) -> Result<()> {
1367    let project_id = ProjectId::from_proto(request.project_id);
1368    let host_connection_id = {
1369        let collaborators = session
1370            .db()
1371            .await
1372            .project_collaborators(project_id, session.connection_id)
1373            .await?;
1374        let host = collaborators
1375            .iter()
1376            .find(|collaborator| collaborator.is_host)
1377            .ok_or_else(|| anyhow!("host not found"))?;
1378        ConnectionId(host.connection_id as u32)
1379    };
1380    let response_payload = session
1381        .peer
1382        .forward_request(session.connection_id, host_connection_id, request.clone())
1383        .await?;
1384
1385    let mut collaborators = session
1386        .db()
1387        .await
1388        .project_collaborators(project_id, session.connection_id)
1389        .await?;
1390    collaborators
1391        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1392    let project_connection_ids = collaborators
1393        .iter()
1394        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1395    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1396        session
1397            .peer
1398            .forward_send(host_connection_id, conn_id, response_payload.clone())
1399    });
1400    response.send(response_payload)?;
1401    Ok(())
1402}
1403
1404async fn create_buffer_for_peer(
1405    request: proto::CreateBufferForPeer,
1406    session: Session,
1407) -> Result<()> {
1408    session.peer.forward_send(
1409        session.connection_id,
1410        ConnectionId(request.peer_id),
1411        request,
1412    )?;
1413    Ok(())
1414}
1415
1416async fn update_buffer(
1417    request: proto::UpdateBuffer,
1418    response: Response<proto::UpdateBuffer>,
1419    session: Session,
1420) -> Result<()> {
1421    let project_id = ProjectId::from_proto(request.project_id);
1422    let project_connection_ids = session
1423        .db()
1424        .await
1425        .project_connection_ids(project_id, session.connection_id)
1426        .await?;
1427
1428    broadcast(
1429        session.connection_id,
1430        project_connection_ids.iter().copied(),
1431        |connection_id| {
1432            session
1433                .peer
1434                .forward_send(session.connection_id, connection_id, request.clone())
1435        },
1436    );
1437    response.send(proto::Ack {})?;
1438    Ok(())
1439}
1440
1441async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1442    let project_id = ProjectId::from_proto(request.project_id);
1443    let project_connection_ids = session
1444        .db()
1445        .await
1446        .project_connection_ids(project_id, session.connection_id)
1447        .await?;
1448
1449    broadcast(
1450        session.connection_id,
1451        project_connection_ids.iter().copied(),
1452        |connection_id| {
1453            session
1454                .peer
1455                .forward_send(session.connection_id, connection_id, request.clone())
1456        },
1457    );
1458    Ok(())
1459}
1460
1461async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1462    let project_id = ProjectId::from_proto(request.project_id);
1463    let project_connection_ids = session
1464        .db()
1465        .await
1466        .project_connection_ids(project_id, session.connection_id)
1467        .await?;
1468    broadcast(
1469        session.connection_id,
1470        project_connection_ids.iter().copied(),
1471        |connection_id| {
1472            session
1473                .peer
1474                .forward_send(session.connection_id, connection_id, request.clone())
1475        },
1476    );
1477    Ok(())
1478}
1479
1480async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1481    let project_id = ProjectId::from_proto(request.project_id);
1482    let project_connection_ids = session
1483        .db()
1484        .await
1485        .project_connection_ids(project_id, session.connection_id)
1486        .await?;
1487    broadcast(
1488        session.connection_id,
1489        project_connection_ids.iter().copied(),
1490        |connection_id| {
1491            session
1492                .peer
1493                .forward_send(session.connection_id, connection_id, request.clone())
1494        },
1495    );
1496    Ok(())
1497}
1498
1499async fn follow(
1500    request: proto::Follow,
1501    response: Response<proto::Follow>,
1502    session: Session,
1503) -> Result<()> {
1504    let project_id = ProjectId::from_proto(request.project_id);
1505    let leader_id = ConnectionId(request.leader_id);
1506    let follower_id = session.connection_id;
1507    {
1508        let project_connection_ids = session
1509            .db()
1510            .await
1511            .project_connection_ids(project_id, session.connection_id)
1512            .await?;
1513
1514        if !project_connection_ids.contains(&leader_id) {
1515            Err(anyhow!("no such peer"))?;
1516        }
1517    }
1518
1519    let mut response_payload = session
1520        .peer
1521        .forward_request(session.connection_id, leader_id, request)
1522        .await?;
1523    response_payload
1524        .views
1525        .retain(|view| view.leader_id != Some(follower_id.0));
1526    response.send(response_payload)?;
1527    Ok(())
1528}
1529
1530async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1531    let project_id = ProjectId::from_proto(request.project_id);
1532    let leader_id = ConnectionId(request.leader_id);
1533    let project_connection_ids = session
1534        .db()
1535        .await
1536        .project_connection_ids(project_id, session.connection_id)
1537        .await?;
1538    if !project_connection_ids.contains(&leader_id) {
1539        Err(anyhow!("no such peer"))?;
1540    }
1541    session
1542        .peer
1543        .forward_send(session.connection_id, leader_id, request)?;
1544    Ok(())
1545}
1546
1547async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1548    let project_id = ProjectId::from_proto(request.project_id);
1549    let project_connection_ids = session
1550        .db
1551        .lock()
1552        .await
1553        .project_connection_ids(project_id, session.connection_id)
1554        .await?;
1555
1556    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1557        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1558        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1559        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1560    });
1561    for follower_id in &request.follower_ids {
1562        let follower_id = ConnectionId(*follower_id);
1563        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1564            session
1565                .peer
1566                .forward_send(session.connection_id, follower_id, request.clone())?;
1567        }
1568    }
1569    Ok(())
1570}
1571
1572async fn get_users(
1573    request: proto::GetUsers,
1574    response: Response<proto::GetUsers>,
1575    session: Session,
1576) -> Result<()> {
1577    let user_ids = request
1578        .user_ids
1579        .into_iter()
1580        .map(UserId::from_proto)
1581        .collect();
1582    let users = session
1583        .db()
1584        .await
1585        .get_users_by_ids(user_ids)
1586        .await?
1587        .into_iter()
1588        .map(|user| proto::User {
1589            id: user.id.to_proto(),
1590            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1591            github_login: user.github_login,
1592        })
1593        .collect();
1594    response.send(proto::UsersResponse { users })?;
1595    Ok(())
1596}
1597
1598async fn fuzzy_search_users(
1599    request: proto::FuzzySearchUsers,
1600    response: Response<proto::FuzzySearchUsers>,
1601    session: Session,
1602) -> Result<()> {
1603    let query = request.query;
1604    let users = match query.len() {
1605        0 => vec![],
1606        1 | 2 => session
1607            .db()
1608            .await
1609            .get_user_by_github_account(&query, None)
1610            .await?
1611            .into_iter()
1612            .collect(),
1613        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1614    };
1615    let users = users
1616        .into_iter()
1617        .filter(|user| user.id != session.user_id)
1618        .map(|user| proto::User {
1619            id: user.id.to_proto(),
1620            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1621            github_login: user.github_login,
1622        })
1623        .collect();
1624    response.send(proto::UsersResponse { users })?;
1625    Ok(())
1626}
1627
1628async fn request_contact(
1629    request: proto::RequestContact,
1630    response: Response<proto::RequestContact>,
1631    session: Session,
1632) -> Result<()> {
1633    let requester_id = session.user_id;
1634    let responder_id = UserId::from_proto(request.responder_id);
1635    if requester_id == responder_id {
1636        return Err(anyhow!("cannot add yourself as a contact"))?;
1637    }
1638
1639    session
1640        .db()
1641        .await
1642        .send_contact_request(requester_id, responder_id)
1643        .await?;
1644
1645    // Update outgoing contact requests of requester
1646    let mut update = proto::UpdateContacts::default();
1647    update.outgoing_requests.push(responder_id.to_proto());
1648    for connection_id in session
1649        .connection_pool()
1650        .await
1651        .user_connection_ids(requester_id)
1652    {
1653        session.peer.send(connection_id, update.clone())?;
1654    }
1655
1656    // Update incoming contact requests of responder
1657    let mut update = proto::UpdateContacts::default();
1658    update
1659        .incoming_requests
1660        .push(proto::IncomingContactRequest {
1661            requester_id: requester_id.to_proto(),
1662            should_notify: true,
1663        });
1664    for connection_id in session
1665        .connection_pool()
1666        .await
1667        .user_connection_ids(responder_id)
1668    {
1669        session.peer.send(connection_id, update.clone())?;
1670    }
1671
1672    response.send(proto::Ack {})?;
1673    Ok(())
1674}
1675
1676async fn respond_to_contact_request(
1677    request: proto::RespondToContactRequest,
1678    response: Response<proto::RespondToContactRequest>,
1679    session: Session,
1680) -> Result<()> {
1681    let responder_id = session.user_id;
1682    let requester_id = UserId::from_proto(request.requester_id);
1683    let db = session.db().await;
1684    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1685        db.dismiss_contact_notification(responder_id, requester_id)
1686            .await?;
1687    } else {
1688        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1689
1690        db.respond_to_contact_request(responder_id, requester_id, accept)
1691            .await?;
1692        let requester_busy = db.is_user_busy(requester_id).await?;
1693        let responder_busy = db.is_user_busy(responder_id).await?;
1694
1695        let pool = session.connection_pool().await;
1696        // Update responder with new contact
1697        let mut update = proto::UpdateContacts::default();
1698        if accept {
1699            update
1700                .contacts
1701                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1702        }
1703        update
1704            .remove_incoming_requests
1705            .push(requester_id.to_proto());
1706        for connection_id in pool.user_connection_ids(responder_id) {
1707            session.peer.send(connection_id, update.clone())?;
1708        }
1709
1710        // Update requester with new contact
1711        let mut update = proto::UpdateContacts::default();
1712        if accept {
1713            update
1714                .contacts
1715                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1716        }
1717        update
1718            .remove_outgoing_requests
1719            .push(responder_id.to_proto());
1720        for connection_id in pool.user_connection_ids(requester_id) {
1721            session.peer.send(connection_id, update.clone())?;
1722        }
1723    }
1724
1725    response.send(proto::Ack {})?;
1726    Ok(())
1727}
1728
1729async fn remove_contact(
1730    request: proto::RemoveContact,
1731    response: Response<proto::RemoveContact>,
1732    session: Session,
1733) -> Result<()> {
1734    let requester_id = session.user_id;
1735    let responder_id = UserId::from_proto(request.user_id);
1736    let db = session.db().await;
1737    db.remove_contact(requester_id, responder_id).await?;
1738
1739    let pool = session.connection_pool().await;
1740    // Update outgoing contact requests of requester
1741    let mut update = proto::UpdateContacts::default();
1742    update
1743        .remove_outgoing_requests
1744        .push(responder_id.to_proto());
1745    for connection_id in pool.user_connection_ids(requester_id) {
1746        session.peer.send(connection_id, update.clone())?;
1747    }
1748
1749    // Update incoming contact requests of responder
1750    let mut update = proto::UpdateContacts::default();
1751    update
1752        .remove_incoming_requests
1753        .push(requester_id.to_proto());
1754    for connection_id in pool.user_connection_ids(responder_id) {
1755        session.peer.send(connection_id, update.clone())?;
1756    }
1757
1758    response.send(proto::Ack {})?;
1759    Ok(())
1760}
1761
1762async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1763    let project_id = ProjectId::from_proto(request.project_id);
1764    let project_connection_ids = session
1765        .db()
1766        .await
1767        .project_connection_ids(project_id, session.connection_id)
1768        .await?;
1769    broadcast(
1770        session.connection_id,
1771        project_connection_ids.iter().copied(),
1772        |connection_id| {
1773            session
1774                .peer
1775                .forward_send(session.connection_id, connection_id, request.clone())
1776        },
1777    );
1778    Ok(())
1779}
1780
1781async fn get_private_user_info(
1782    _request: proto::GetPrivateUserInfo,
1783    response: Response<proto::GetPrivateUserInfo>,
1784    session: Session,
1785) -> Result<()> {
1786    let metrics_id = session
1787        .db()
1788        .await
1789        .get_user_metrics_id(session.user_id)
1790        .await?;
1791    let user = session
1792        .db()
1793        .await
1794        .get_user_by_id(session.user_id)
1795        .await?
1796        .ok_or_else(|| anyhow!("user not found"))?;
1797    response.send(proto::GetPrivateUserInfoResponse {
1798        metrics_id,
1799        staff: user.admin,
1800    })?;
1801    Ok(())
1802}
1803
1804fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1805    match message {
1806        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1807        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1808        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1809        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1810        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1811            code: frame.code.into(),
1812            reason: frame.reason,
1813        })),
1814    }
1815}
1816
1817fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1818    match message {
1819        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1820        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1821        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1822        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1823        AxumMessage::Close(frame) => {
1824            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1825                code: frame.code.into(),
1826                reason: frame.reason,
1827            }))
1828        }
1829    }
1830}
1831
1832fn build_initial_contacts_update(
1833    contacts: Vec<db::Contact>,
1834    pool: &ConnectionPool,
1835) -> proto::UpdateContacts {
1836    let mut update = proto::UpdateContacts::default();
1837
1838    for contact in contacts {
1839        match contact {
1840            db::Contact::Accepted {
1841                user_id,
1842                should_notify,
1843                busy,
1844            } => {
1845                update
1846                    .contacts
1847                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1848            }
1849            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1850            db::Contact::Incoming {
1851                user_id,
1852                should_notify,
1853            } => update
1854                .incoming_requests
1855                .push(proto::IncomingContactRequest {
1856                    requester_id: user_id.to_proto(),
1857                    should_notify,
1858                }),
1859        }
1860    }
1861
1862    update
1863}
1864
1865fn contact_for_user(
1866    user_id: UserId,
1867    should_notify: bool,
1868    busy: bool,
1869    pool: &ConnectionPool,
1870) -> proto::Contact {
1871    proto::Contact {
1872        user_id: user_id.to_proto(),
1873        online: pool.is_user_online(user_id),
1874        busy,
1875        should_notify,
1876    }
1877}
1878
1879fn room_updated(room: &proto::Room, peer: &Peer) {
1880    for participant in &room.participants {
1881        peer.send(
1882            ConnectionId(participant.peer_id),
1883            proto::RoomUpdated {
1884                room: Some(room.clone()),
1885            },
1886        )
1887        .trace_err();
1888    }
1889}
1890
1891async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1892    let db = session.db().await;
1893    let contacts = db.get_contacts(user_id).await?;
1894    let busy = db.is_user_busy(user_id).await?;
1895
1896    let pool = session.connection_pool().await;
1897    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1898    for contact in contacts {
1899        if let db::Contact::Accepted {
1900            user_id: contact_user_id,
1901            ..
1902        } = contact
1903        {
1904            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1905                session
1906                    .peer
1907                    .send(
1908                        contact_conn_id,
1909                        proto::UpdateContacts {
1910                            contacts: vec![updated_contact.clone()],
1911                            remove_contacts: Default::default(),
1912                            incoming_requests: Default::default(),
1913                            remove_incoming_requests: Default::default(),
1914                            outgoing_requests: Default::default(),
1915                            remove_outgoing_requests: Default::default(),
1916                        },
1917                    )
1918                    .trace_err();
1919            }
1920        }
1921    }
1922    Ok(())
1923}
1924
1925async fn leave_room_for_session(session: &Session) -> Result<()> {
1926    let mut contacts_to_update = HashSet::default();
1927
1928    let canceled_calls_to_user_ids;
1929    let live_kit_room;
1930    let delete_live_kit_room;
1931    {
1932        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1933        contacts_to_update.insert(session.user_id);
1934
1935        for project in left_room.left_projects.values() {
1936            project_left(project, session);
1937        }
1938
1939        room_updated(&left_room.room, &session.peer);
1940        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1941        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1942        delete_live_kit_room = left_room.room.participants.is_empty();
1943    }
1944
1945    {
1946        let pool = session.connection_pool().await;
1947        for canceled_user_id in canceled_calls_to_user_ids {
1948            for connection_id in pool.user_connection_ids(canceled_user_id) {
1949                session
1950                    .peer
1951                    .send(connection_id, proto::CallCanceled {})
1952                    .trace_err();
1953            }
1954            contacts_to_update.insert(canceled_user_id);
1955        }
1956    }
1957
1958    for contact_user_id in contacts_to_update {
1959        update_user_contacts(contact_user_id, &session).await?;
1960    }
1961
1962    if let Some(live_kit) = session.live_kit_client.as_ref() {
1963        live_kit
1964            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1965            .await
1966            .trace_err();
1967
1968        if delete_live_kit_room {
1969            live_kit.delete_room(live_kit_room).await.trace_err();
1970        }
1971    }
1972
1973    Ok(())
1974}
1975
1976fn project_left(project: &db::LeftProject, session: &Session) {
1977    for connection_id in &project.connection_ids {
1978        if project.host_user_id == session.user_id {
1979            session
1980                .peer
1981                .send(
1982                    *connection_id,
1983                    proto::UnshareProject {
1984                        project_id: project.id.to_proto(),
1985                    },
1986                )
1987                .trace_err();
1988        } else {
1989            session
1990                .peer
1991                .send(
1992                    *connection_id,
1993                    proto::RemoveProjectCollaborator {
1994                        project_id: project.id.to_proto(),
1995                        peer_id: session.connection_id.0,
1996                    },
1997                )
1998                .trace_err();
1999        }
2000    }
2001
2002    session
2003        .peer
2004        .send(
2005            session.connection_id,
2006            proto::UnshareProject {
2007                project_id: project.id.to_proto(),
2008            },
2009        )
2010        .trace_err();
2011}
2012
2013pub trait ResultExt {
2014    type Ok;
2015
2016    fn trace_err(self) -> Option<Self::Ok>;
2017}
2018
2019impl<T, E> ResultExt for Result<T, E>
2020where
2021    E: std::fmt::Debug,
2022{
2023    type Ok = T;
2024
2025    fn trace_err(self) -> Option<T> {
2026        match self {
2027            Ok(value) => Some(value),
2028            Err(error) => {
2029                tracing::error!("{:?}", error);
2030                None
2031            }
2032        }
2033    }
2034}