rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    id: parking_lot::Mutex<ServerId>,
 142    peer: Arc<Peer>,
 143    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 144    app_state: Arc<AppState>,
 145    executor: Executor,
 146    handlers: HashMap<TypeId, MessageHandler>,
 147    teardown: watch::Sender<()>,
 148}
 149
 150pub(crate) struct ConnectionPoolGuard<'a> {
 151    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 152    _not_send: PhantomData<Rc<()>>,
 153}
 154
 155#[derive(Serialize)]
 156pub struct ServerSnapshot<'a> {
 157    peer: &'a Peer,
 158    #[serde(serialize_with = "serialize_deref")]
 159    connection_pool: ConnectionPoolGuard<'a>,
 160}
 161
 162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 163where
 164    S: Serializer,
 165    T: Deref<Target = U>,
 166    U: Serialize,
 167{
 168    Serialize::serialize(value.deref(), serializer)
 169}
 170
 171impl Server {
 172    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 173        let mut server = Self {
 174            id: parking_lot::Mutex::new(id),
 175            peer: Peer::new(id.0 as u32),
 176            app_state,
 177            executor,
 178            connection_pool: Default::default(),
 179            handlers: Default::default(),
 180            teardown: watch::channel(()).0,
 181        };
 182
 183        server
 184            .add_request_handler(ping)
 185            .add_request_handler(create_room)
 186            .add_request_handler(join_room)
 187            .add_message_handler(leave_room)
 188            .add_request_handler(call)
 189            .add_request_handler(cancel_call)
 190            .add_message_handler(decline_call)
 191            .add_request_handler(update_participant_location)
 192            .add_request_handler(share_project)
 193            .add_message_handler(unshare_project)
 194            .add_request_handler(join_project)
 195            .add_message_handler(leave_project)
 196            .add_request_handler(update_project)
 197            .add_request_handler(update_worktree)
 198            .add_message_handler(start_language_server)
 199            .add_message_handler(update_language_server)
 200            .add_message_handler(update_diagnostic_summary)
 201            .add_request_handler(forward_project_request::<proto::GetHover>)
 202            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 203            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 204            .add_request_handler(forward_project_request::<proto::GetReferences>)
 205            .add_request_handler(forward_project_request::<proto::SearchProject>)
 206            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 207            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 209            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 210            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 211            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 213            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 214            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 215            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 216            .add_request_handler(forward_project_request::<proto::PerformRename>)
 217            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 218            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 219            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 221            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 222            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 223            .add_message_handler(create_buffer_for_peer)
 224            .add_request_handler(update_buffer)
 225            .add_message_handler(update_buffer_file)
 226            .add_message_handler(buffer_reloaded)
 227            .add_message_handler(buffer_saved)
 228            .add_request_handler(save_buffer)
 229            .add_request_handler(get_users)
 230            .add_request_handler(fuzzy_search_users)
 231            .add_request_handler(request_contact)
 232            .add_request_handler(remove_contact)
 233            .add_request_handler(respond_to_contact_request)
 234            .add_request_handler(follow)
 235            .add_message_handler(unfollow)
 236            .add_message_handler(update_followers)
 237            .add_message_handler(update_diff_base)
 238            .add_request_handler(get_private_user_info);
 239
 240        Arc::new(server)
 241    }
 242
 243    pub async fn start(&self) -> Result<()> {
 244        let server_id = *self.id.lock();
 245        let app_state = self.app_state.clone();
 246        let peer = self.peer.clone();
 247        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 248        let pool = self.connection_pool.clone();
 249        let live_kit_client = self.app_state.live_kit_client.clone();
 250
 251        let span = info_span!("start server");
 252        let span_enter = span.enter();
 253
 254        tracing::info!("begin deleting stale projects");
 255        app_state
 256            .db
 257            .delete_stale_projects(&app_state.config.zed_environment, server_id)
 258            .await?;
 259        tracing::info!("finish deleting stale projects");
 260
 261        drop(span_enter);
 262        self.executor.spawn_detached(
 263            async move {
 264                tracing::info!("waiting for cleanup timeout");
 265                timeout.await;
 266                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 267                if let Some(room_ids) = app_state
 268                    .db
 269                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 270                    .await
 271                    .trace_err()
 272                {
 273                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 274                    for room_id in room_ids {
 275                        let mut contacts_to_update = HashSet::default();
 276                        let mut canceled_calls_to_user_ids = Vec::new();
 277                        let mut live_kit_room = String::new();
 278                        let mut delete_live_kit_room = false;
 279
 280                        if let Ok(mut refreshed_room) =
 281                            app_state.db.refresh_room(room_id, server_id).await
 282                        {
 283                            tracing::info!(
 284                                room_id = room_id.0,
 285                                new_participant_count = refreshed_room.room.participants.len(),
 286                                "refreshed room"
 287                            );
 288                            room_updated(&refreshed_room.room, &peer);
 289                            contacts_to_update
 290                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 291                            contacts_to_update
 292                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 293                            canceled_calls_to_user_ids =
 294                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 295                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 296                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 297                        }
 298
 299                        {
 300                            let pool = pool.lock();
 301                            for canceled_user_id in canceled_calls_to_user_ids {
 302                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 303                                    peer.send(
 304                                        connection_id,
 305                                        proto::CallCanceled {
 306                                            room_id: room_id.to_proto(),
 307                                        },
 308                                    )
 309                                    .trace_err();
 310                                }
 311                            }
 312                        }
 313
 314                        for user_id in contacts_to_update {
 315                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 316                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 317                            if let Some((busy, contacts)) = busy.zip(contacts) {
 318                                let pool = pool.lock();
 319                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 320                                for contact in contacts {
 321                                    if let db::Contact::Accepted {
 322                                        user_id: contact_user_id,
 323                                        ..
 324                                    } = contact
 325                                    {
 326                                        for contact_conn_id in
 327                                            pool.user_connection_ids(contact_user_id)
 328                                        {
 329                                            peer.send(
 330                                                contact_conn_id,
 331                                                proto::UpdateContacts {
 332                                                    contacts: vec![updated_contact.clone()],
 333                                                    remove_contacts: Default::default(),
 334                                                    incoming_requests: Default::default(),
 335                                                    remove_incoming_requests: Default::default(),
 336                                                    outgoing_requests: Default::default(),
 337                                                    remove_outgoing_requests: Default::default(),
 338                                                },
 339                                            )
 340                                            .trace_err();
 341                                        }
 342                                    }
 343                                }
 344                            }
 345                        }
 346
 347                        if let Some(live_kit) = live_kit_client.as_ref() {
 348                            if delete_live_kit_room {
 349                                live_kit.delete_room(live_kit_room).await.trace_err();
 350                            }
 351                        }
 352                    }
 353                }
 354
 355                app_state
 356                    .db
 357                    .delete_stale_servers(server_id, &app_state.config.zed_environment)
 358                    .await
 359                    .trace_err();
 360            }
 361            .instrument(span),
 362        );
 363        Ok(())
 364    }
 365
 366    pub fn teardown(&self) {
 367        self.peer.teardown();
 368        self.connection_pool.lock().reset();
 369        let _ = self.teardown.send(());
 370    }
 371
 372    #[cfg(test)]
 373    pub fn reset(&self, id: ServerId) {
 374        self.teardown();
 375        *self.id.lock() = id;
 376        self.peer.reset(id.0 as u32);
 377    }
 378
 379    #[cfg(test)]
 380    pub fn id(&self) -> ServerId {
 381        *self.id.lock()
 382    }
 383
 384    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 385    where
 386        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 387        Fut: 'static + Send + Future<Output = Result<()>>,
 388        M: EnvelopedMessage,
 389    {
 390        let prev_handler = self.handlers.insert(
 391            TypeId::of::<M>(),
 392            Box::new(move |envelope, session| {
 393                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 394                let span = info_span!(
 395                    "handle message",
 396                    payload_type = envelope.payload_type_name()
 397                );
 398                span.in_scope(|| {
 399                    tracing::info!(
 400                        payload_type = envelope.payload_type_name(),
 401                        "message received"
 402                    );
 403                });
 404                let future = (handler)(*envelope, session);
 405                async move {
 406                    if let Err(error) = future.await {
 407                        tracing::error!(%error, "error handling message");
 408                    }
 409                }
 410                .instrument(span)
 411                .boxed()
 412            }),
 413        );
 414        if prev_handler.is_some() {
 415            panic!("registered a handler for the same message twice");
 416        }
 417        self
 418    }
 419
 420    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 421    where
 422        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 423        Fut: 'static + Send + Future<Output = Result<()>>,
 424        M: EnvelopedMessage,
 425    {
 426        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 427        self
 428    }
 429
 430    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 431    where
 432        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 433        Fut: Send + Future<Output = Result<()>>,
 434        M: RequestMessage,
 435    {
 436        let handler = Arc::new(handler);
 437        self.add_handler(move |envelope, session| {
 438            let receipt = envelope.receipt();
 439            let handler = handler.clone();
 440            async move {
 441                let peer = session.peer.clone();
 442                let responded = Arc::new(AtomicBool::default());
 443                let response = Response {
 444                    peer: peer.clone(),
 445                    responded: responded.clone(),
 446                    receipt,
 447                };
 448                match (handler)(envelope.payload, response, session).await {
 449                    Ok(()) => {
 450                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 451                            Ok(())
 452                        } else {
 453                            Err(anyhow!("handler did not send a response"))?
 454                        }
 455                    }
 456                    Err(error) => {
 457                        peer.respond_with_error(
 458                            receipt,
 459                            proto::Error {
 460                                message: error.to_string(),
 461                            },
 462                        )?;
 463                        Err(error)
 464                    }
 465                }
 466            }
 467        })
 468    }
 469
 470    pub fn handle_connection(
 471        self: &Arc<Self>,
 472        connection: Connection,
 473        address: String,
 474        user: User,
 475        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 476        executor: Executor,
 477    ) -> impl Future<Output = Result<()>> {
 478        let this = self.clone();
 479        let user_id = user.id;
 480        let login = user.github_login;
 481        let span = info_span!("handle connection", %user_id, %login, %address);
 482        let mut teardown = self.teardown.subscribe();
 483        async move {
 484            let (connection_id, handle_io, mut incoming_rx) = this
 485                .peer
 486                .add_connection(connection, {
 487                    let executor = executor.clone();
 488                    move |duration| executor.sleep(duration)
 489                });
 490
 491            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 492            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 493            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 494
 495            if let Some(send_connection_id) = send_connection_id.take() {
 496                let _ = send_connection_id.send(connection_id);
 497            }
 498
 499            if !user.connected_once {
 500                this.peer.send(connection_id, proto::ShowContacts {})?;
 501                this.app_state.db.set_user_connected_once(user_id, true).await?;
 502            }
 503
 504            let (contacts, invite_code) = future::try_join(
 505                this.app_state.db.get_contacts(user_id),
 506                this.app_state.db.get_invite_code_for_user(user_id)
 507            ).await?;
 508
 509            {
 510                let mut pool = this.connection_pool.lock();
 511                pool.add_connection(connection_id, user_id, user.admin);
 512                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 513
 514                if let Some((code, count)) = invite_code {
 515                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 516                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 517                        count: count as u32,
 518                    })?;
 519                }
 520            }
 521
 522            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 523                this.peer.send(connection_id, incoming_call)?;
 524            }
 525
 526            let session = Session {
 527                user_id,
 528                connection_id,
 529                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 530                peer: this.peer.clone(),
 531                connection_pool: this.connection_pool.clone(),
 532                live_kit_client: this.app_state.live_kit_client.clone()
 533            };
 534            update_user_contacts(user_id, &session).await?;
 535
 536            let handle_io = handle_io.fuse();
 537            futures::pin_mut!(handle_io);
 538
 539            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 540            // This prevents deadlocks when e.g., client A performs a request to client B and
 541            // client B performs a request to client A. If both clients stop processing further
 542            // messages until their respective request completes, they won't have a chance to
 543            // respond to the other client's request and cause a deadlock.
 544            //
 545            // This arrangement ensures we will attempt to process earlier messages first, but fall
 546            // back to processing messages arrived later in the spirit of making progress.
 547            let mut foreground_message_handlers = FuturesUnordered::new();
 548            loop {
 549                let next_message = incoming_rx.next().fuse();
 550                futures::pin_mut!(next_message);
 551                futures::select_biased! {
 552                    _ = teardown.changed().fuse() => return Ok(()),
 553                    result = handle_io => {
 554                        if let Err(error) = result {
 555                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 556                        }
 557                        break;
 558                    }
 559                    _ = foreground_message_handlers.next() => {}
 560                    message = next_message => {
 561                        if let Some(message) = message {
 562                            let type_name = message.payload_type_name();
 563                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 564                            let span_enter = span.enter();
 565                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 566                                let is_background = message.is_background();
 567                                let handle_message = (handler)(message, session.clone());
 568                                drop(span_enter);
 569
 570                                let handle_message = handle_message.instrument(span);
 571                                if is_background {
 572                                    executor.spawn_detached(handle_message);
 573                                } else {
 574                                    foreground_message_handlers.push(handle_message);
 575                                }
 576                            } else {
 577                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 578                            }
 579                        } else {
 580                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 581                            break;
 582                        }
 583                    }
 584                }
 585            }
 586
 587            drop(foreground_message_handlers);
 588            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 589            if let Err(error) = sign_out(session, teardown, executor).await {
 590                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 591            }
 592
 593            Ok(())
 594        }.instrument(span)
 595    }
 596
 597    pub async fn invite_code_redeemed(
 598        self: &Arc<Self>,
 599        inviter_id: UserId,
 600        invitee_id: UserId,
 601    ) -> Result<()> {
 602        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 603            if let Some(code) = &user.invite_code {
 604                let pool = self.connection_pool.lock();
 605                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 606                for connection_id in pool.user_connection_ids(inviter_id) {
 607                    self.peer.send(
 608                        connection_id,
 609                        proto::UpdateContacts {
 610                            contacts: vec![invitee_contact.clone()],
 611                            ..Default::default()
 612                        },
 613                    )?;
 614                    self.peer.send(
 615                        connection_id,
 616                        proto::UpdateInviteInfo {
 617                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 618                            count: user.invite_count as u32,
 619                        },
 620                    )?;
 621                }
 622            }
 623        }
 624        Ok(())
 625    }
 626
 627    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 628        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 629            if let Some(invite_code) = &user.invite_code {
 630                let pool = self.connection_pool.lock();
 631                for connection_id in pool.user_connection_ids(user_id) {
 632                    self.peer.send(
 633                        connection_id,
 634                        proto::UpdateInviteInfo {
 635                            url: format!(
 636                                "{}{}",
 637                                self.app_state.config.invite_link_prefix, invite_code
 638                            ),
 639                            count: user.invite_count as u32,
 640                        },
 641                    )?;
 642                }
 643            }
 644        }
 645        Ok(())
 646    }
 647
 648    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 649        ServerSnapshot {
 650            connection_pool: ConnectionPoolGuard {
 651                guard: self.connection_pool.lock(),
 652                _not_send: PhantomData,
 653            },
 654            peer: &self.peer,
 655        }
 656    }
 657}
 658
 659impl<'a> Deref for ConnectionPoolGuard<'a> {
 660    type Target = ConnectionPool;
 661
 662    fn deref(&self) -> &Self::Target {
 663        &*self.guard
 664    }
 665}
 666
 667impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 668    fn deref_mut(&mut self) -> &mut Self::Target {
 669        &mut *self.guard
 670    }
 671}
 672
 673impl<'a> Drop for ConnectionPoolGuard<'a> {
 674    fn drop(&mut self) {
 675        #[cfg(test)]
 676        self.check_invariants();
 677    }
 678}
 679
 680fn broadcast<F>(
 681    sender_id: ConnectionId,
 682    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 683    mut f: F,
 684) where
 685    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 686{
 687    for receiver_id in receiver_ids {
 688        if receiver_id != sender_id {
 689            f(receiver_id).trace_err();
 690        }
 691    }
 692}
 693
 694lazy_static! {
 695    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 696}
 697
 698pub struct ProtocolVersion(u32);
 699
 700impl Header for ProtocolVersion {
 701    fn name() -> &'static HeaderName {
 702        &ZED_PROTOCOL_VERSION
 703    }
 704
 705    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 706    where
 707        Self: Sized,
 708        I: Iterator<Item = &'i axum::http::HeaderValue>,
 709    {
 710        let version = values
 711            .next()
 712            .ok_or_else(axum::headers::Error::invalid)?
 713            .to_str()
 714            .map_err(|_| axum::headers::Error::invalid())?
 715            .parse()
 716            .map_err(|_| axum::headers::Error::invalid())?;
 717        Ok(Self(version))
 718    }
 719
 720    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 721        values.extend([self.0.to_string().parse().unwrap()]);
 722    }
 723}
 724
 725pub fn routes(server: Arc<Server>) -> Router<Body> {
 726    Router::new()
 727        .route("/rpc", get(handle_websocket_request))
 728        .layer(
 729            ServiceBuilder::new()
 730                .layer(Extension(server.app_state.clone()))
 731                .layer(middleware::from_fn(auth::validate_header)),
 732        )
 733        .route("/metrics", get(handle_metrics))
 734        .layer(Extension(server))
 735}
 736
 737pub async fn handle_websocket_request(
 738    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 739    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 740    Extension(server): Extension<Arc<Server>>,
 741    Extension(user): Extension<User>,
 742    ws: WebSocketUpgrade,
 743) -> axum::response::Response {
 744    if protocol_version != rpc::PROTOCOL_VERSION {
 745        return (
 746            StatusCode::UPGRADE_REQUIRED,
 747            "client must be upgraded".to_string(),
 748        )
 749            .into_response();
 750    }
 751    let socket_address = socket_address.to_string();
 752    ws.on_upgrade(move |socket| {
 753        use util::ResultExt;
 754        let socket = socket
 755            .map_ok(to_tungstenite_message)
 756            .err_into()
 757            .with(|message| async move { Ok(to_axum_message(message)) });
 758        let connection = Connection::new(Box::pin(socket));
 759        async move {
 760            server
 761                .handle_connection(connection, socket_address, user, None, Executor::Production)
 762                .await
 763                .log_err();
 764        }
 765    })
 766}
 767
 768pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 769    let connections = server
 770        .connection_pool
 771        .lock()
 772        .connections()
 773        .filter(|connection| !connection.admin)
 774        .count();
 775
 776    METRIC_CONNECTIONS.set(connections as _);
 777
 778    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 779    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 780
 781    let encoder = prometheus::TextEncoder::new();
 782    let metric_families = prometheus::gather();
 783    let encoded_metrics = encoder
 784        .encode_to_string(&metric_families)
 785        .map_err(|err| anyhow!("{}", err))?;
 786    Ok(encoded_metrics)
 787}
 788
/// Cleans up after a connection has gone away.
///
/// The function first disconnects the peer and removes the connection from the pool. It
/// then reports the lost connection to the database and notifies collaborators of every
/// project the connection had left. Finally, it waits for either `RECONNECT_TIMEOUT` to
/// elapse or `teardown` to change:
///
/// - On timeout, the session leaves its room. If the user has no other online connection,
///   `decline_call(None, ..)` is issued and any returned room is broadcast. The user's
///   contacts are then refreshed.
/// - On teardown, no further cleanup is done here.
///
/// Failures in the database and room cleanup steps are only logged (`trace_err`). Failing
/// to remove the connection from the pool, or to update contacts, is returned as an error.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        // Move the projects out of the returned value and notify each one's collaborators.
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls its branches in order, so the timeout branch wins if both
    // futures are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline calls once the user has no remaining online connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 834
 835async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 836    response.send(proto::Ack {})?;
 837    Ok(())
 838}
 839
 840async fn create_room(
 841    _request: proto::CreateRoom,
 842    response: Response<proto::CreateRoom>,
 843    session: Session,
 844) -> Result<()> {
 845    let live_kit_room = nanoid::nanoid!(30);
 846    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 847        if let Some(_) = live_kit
 848            .create_room(live_kit_room.clone())
 849            .await
 850            .trace_err()
 851        {
 852            if let Some(token) = live_kit
 853                .room_token(&live_kit_room, &session.user_id.to_string())
 854                .trace_err()
 855            {
 856                Some(proto::LiveKitConnectionInfo {
 857                    server_url: live_kit.url().into(),
 858                    token,
 859                })
 860            } else {
 861                None
 862            }
 863        } else {
 864            None
 865        }
 866    } else {
 867        None
 868    };
 869
 870    {
 871        let room = session
 872            .db()
 873            .await
 874            .create_room(session.user_id, session.connection_id, &live_kit_room)
 875            .await?;
 876
 877        response.send(proto::CreateRoomResponse {
 878            room: Some(room.clone()),
 879            live_kit_connection_info,
 880        })?;
 881    }
 882
 883    update_user_contacts(session.user_id, &session).await?;
 884    Ok(())
 885}
 886
 887async fn join_room(
 888    request: proto::JoinRoom,
 889    response: Response<proto::JoinRoom>,
 890    session: Session,
 891) -> Result<()> {
 892    let room_id = RoomId::from_proto(request.id);
 893    let room = {
 894        let room = session
 895            .db()
 896            .await
 897            .join_room(room_id, session.user_id, session.connection_id)
 898            .await?;
 899        room_updated(&room, &session.peer);
 900        room.clone()
 901    };
 902
 903    for connection_id in session
 904        .connection_pool()
 905        .await
 906        .user_connection_ids(session.user_id)
 907    {
 908        session
 909            .peer
 910            .send(
 911                connection_id,
 912                proto::CallCanceled {
 913                    room_id: room_id.to_proto(),
 914                },
 915            )
 916            .trace_err();
 917    }
 918
 919    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 920        if let Some(token) = live_kit
 921            .room_token(&room.live_kit_room, &session.user_id.to_string())
 922            .trace_err()
 923        {
 924            Some(proto::LiveKitConnectionInfo {
 925                server_url: live_kit.url().into(),
 926                token,
 927            })
 928        } else {
 929            None
 930        }
 931    } else {
 932        None
 933    };
 934
 935    response.send(proto::JoinRoomResponse {
 936        room: Some(room),
 937        live_kit_connection_info,
 938    })?;
 939
 940    update_user_contacts(session.user_id, &session).await?;
 941    Ok(())
 942}
 943
 944async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 945    leave_room_for_session(&session).await
 946}
 947
/// Rings `called_user_id` on behalf of the calling connection.
///
/// The caller must have the callee as a contact. The call is first recorded in the
/// database and the room update is broadcast. The incoming-call request is then sent
/// concurrently to every connection of the callee. The first successful response
/// acknowledges the caller's request. If no connection answers successfully (including
/// when the callee has no connections), the database is told the call failed, the room
/// update is broadcast, and the error "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call and take the incoming-call message out of the database result.
    // The block scope drops that result before contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to respond successfully completes the call; failures are logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: mark the call as failed and publish the updated room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1015
1016async fn cancel_call(
1017    request: proto::CancelCall,
1018    response: Response<proto::CancelCall>,
1019    session: Session,
1020) -> Result<()> {
1021    let called_user_id = UserId::from_proto(request.called_user_id);
1022    let room_id = RoomId::from_proto(request.room_id);
1023    {
1024        let room = session
1025            .db()
1026            .await
1027            .cancel_call(room_id, session.connection_id, called_user_id)
1028            .await?;
1029        room_updated(&room, &session.peer);
1030    }
1031
1032    for connection_id in session
1033        .connection_pool()
1034        .await
1035        .user_connection_ids(called_user_id)
1036    {
1037        session
1038            .peer
1039            .send(
1040                connection_id,
1041                proto::CallCanceled {
1042                    room_id: room_id.to_proto(),
1043                },
1044            )
1045            .trace_err();
1046    }
1047    response.send(proto::Ack {})?;
1048
1049    update_user_contacts(called_user_id, &session).await?;
1050    Ok(())
1051}
1052
1053async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1054    let room_id = RoomId::from_proto(message.room_id);
1055    {
1056        let room = session
1057            .db()
1058            .await
1059            .decline_call(Some(room_id), session.user_id)
1060            .await?
1061            .ok_or_else(|| anyhow!("failed to decline call"))?;
1062        room_updated(&room, &session.peer);
1063    }
1064
1065    for connection_id in session
1066        .connection_pool()
1067        .await
1068        .user_connection_ids(session.user_id)
1069    {
1070        session
1071            .peer
1072            .send(
1073                connection_id,
1074                proto::CallCanceled {
1075                    room_id: room_id.to_proto(),
1076                },
1077            )
1078            .trace_err();
1079    }
1080    update_user_contacts(session.user_id, &session).await?;
1081    Ok(())
1082}
1083
1084async fn update_participant_location(
1085    request: proto::UpdateParticipantLocation,
1086    response: Response<proto::UpdateParticipantLocation>,
1087    session: Session,
1088) -> Result<()> {
1089    let room_id = RoomId::from_proto(request.room_id);
1090    let location = request
1091        .location
1092        .ok_or_else(|| anyhow!("invalid location"))?;
1093    let room = session
1094        .db()
1095        .await
1096        .update_room_participant_location(room_id, session.connection_id, location)
1097        .await?;
1098    room_updated(&room, &session.peer);
1099    response.send(proto::Ack {})?;
1100    Ok(())
1101}
1102
1103async fn share_project(
1104    request: proto::ShareProject,
1105    response: Response<proto::ShareProject>,
1106    session: Session,
1107) -> Result<()> {
1108    let (project_id, room) = &*session
1109        .db()
1110        .await
1111        .share_project(
1112            RoomId::from_proto(request.room_id),
1113            session.connection_id,
1114            &request.worktrees,
1115        )
1116        .await?;
1117    response.send(proto::ShareProjectResponse {
1118        project_id: project_id.to_proto(),
1119    })?;
1120    room_updated(&room, &session.peer);
1121
1122    Ok(())
1123}
1124
1125async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1126    let project_id = ProjectId::from_proto(message.project_id);
1127
1128    let (room, guest_connection_ids) = &*session
1129        .db()
1130        .await
1131        .unshare_project(project_id, session.connection_id)
1132        .await?;
1133
1134    broadcast(
1135        session.connection_id,
1136        guest_connection_ids.iter().copied(),
1137        |conn_id| session.peer.send(conn_id, message.clone()),
1138    );
1139    room_updated(&room, &session.peer);
1140
1141    Ok(())
1142}
1143
1144async fn join_project(
1145    request: proto::JoinProject,
1146    response: Response<proto::JoinProject>,
1147    session: Session,
1148) -> Result<()> {
1149    let project_id = ProjectId::from_proto(request.project_id);
1150    let guest_user_id = session.user_id;
1151
1152    tracing::info!(%project_id, "join project");
1153
1154    let (project, replica_id) = &mut *session
1155        .db()
1156        .await
1157        .join_project(project_id, session.connection_id)
1158        .await?;
1159
1160    let collaborators = project
1161        .collaborators
1162        .iter()
1163        .map(|collaborator| {
1164            let peer_id = proto::PeerId {
1165                owner_id: collaborator.connection_server_id.0 as u32,
1166                id: collaborator.connection_id as u32,
1167            };
1168            proto::Collaborator {
1169                peer_id: Some(peer_id),
1170                replica_id: collaborator.replica_id.0 as u32,
1171                user_id: collaborator.user_id.to_proto(),
1172            }
1173        })
1174        .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1175        .collect::<Vec<_>>();
1176    let worktrees = project
1177        .worktrees
1178        .iter()
1179        .map(|(id, worktree)| proto::WorktreeMetadata {
1180            id: *id,
1181            root_name: worktree.root_name.clone(),
1182            visible: worktree.visible,
1183            abs_path: worktree.abs_path.clone(),
1184        })
1185        .collect::<Vec<_>>();
1186
1187    for collaborator in &collaborators {
1188        session
1189            .peer
1190            .send(
1191                collaborator.peer_id.unwrap().into(),
1192                proto::AddProjectCollaborator {
1193                    project_id: project_id.to_proto(),
1194                    collaborator: Some(proto::Collaborator {
1195                        peer_id: Some(session.connection_id.into()),
1196                        replica_id: replica_id.0 as u32,
1197                        user_id: guest_user_id.to_proto(),
1198                    }),
1199                },
1200            )
1201            .trace_err();
1202    }
1203
1204    // First, we send the metadata associated with each worktree.
1205    response.send(proto::JoinProjectResponse {
1206        worktrees: worktrees.clone(),
1207        replica_id: replica_id.0 as u32,
1208        collaborators: collaborators.clone(),
1209        language_servers: project.language_servers.clone(),
1210    })?;
1211
1212    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1213        #[cfg(any(test, feature = "test-support"))]
1214        const MAX_CHUNK_SIZE: usize = 2;
1215        #[cfg(not(any(test, feature = "test-support")))]
1216        const MAX_CHUNK_SIZE: usize = 256;
1217
1218        // Stream this worktree's entries.
1219        let message = proto::UpdateWorktree {
1220            project_id: project_id.to_proto(),
1221            worktree_id,
1222            abs_path: worktree.abs_path.clone(),
1223            root_name: worktree.root_name,
1224            updated_entries: worktree.entries,
1225            removed_entries: Default::default(),
1226            scan_id: worktree.scan_id,
1227            is_last_update: worktree.is_complete,
1228        };
1229        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1230            session.peer.send(session.connection_id, update.clone())?;
1231        }
1232
1233        // Stream this worktree's diagnostics.
1234        for summary in worktree.diagnostic_summaries {
1235            session.peer.send(
1236                session.connection_id,
1237                proto::UpdateDiagnosticSummary {
1238                    project_id: project_id.to_proto(),
1239                    worktree_id: worktree.id,
1240                    summary: Some(summary),
1241                },
1242            )?;
1243        }
1244    }
1245
1246    for language_server in &project.language_servers {
1247        session.peer.send(
1248            session.connection_id,
1249            proto::UpdateLanguageServer {
1250                project_id: project_id.to_proto(),
1251                language_server_id: language_server.id,
1252                variant: Some(
1253                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1254                        proto::LspDiskBasedDiagnosticsUpdated {},
1255                    ),
1256                ),
1257            },
1258        )?;
1259    }
1260
1261    Ok(())
1262}
1263
1264async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1265    let sender_id = session.connection_id;
1266    let project_id = ProjectId::from_proto(request.project_id);
1267
1268    let project = session
1269        .db()
1270        .await
1271        .leave_project(project_id, sender_id)
1272        .await?;
1273    tracing::info!(
1274        %project_id,
1275        host_user_id = %project.host_user_id,
1276        host_connection_id = %project.host_connection_id,
1277        "leave project"
1278    );
1279    project_left(&project, &session);
1280
1281    Ok(())
1282}
1283
1284async fn update_project(
1285    request: proto::UpdateProject,
1286    response: Response<proto::UpdateProject>,
1287    session: Session,
1288) -> Result<()> {
1289    let project_id = ProjectId::from_proto(request.project_id);
1290    let (room, guest_connection_ids) = &*session
1291        .db()
1292        .await
1293        .update_project(project_id, session.connection_id, &request.worktrees)
1294        .await?;
1295    broadcast(
1296        session.connection_id,
1297        guest_connection_ids.iter().copied(),
1298        |connection_id| {
1299            session
1300                .peer
1301                .forward_send(session.connection_id, connection_id, request.clone())
1302        },
1303    );
1304    room_updated(&room, &session.peer);
1305    response.send(proto::Ack {})?;
1306
1307    Ok(())
1308}
1309
1310async fn update_worktree(
1311    request: proto::UpdateWorktree,
1312    response: Response<proto::UpdateWorktree>,
1313    session: Session,
1314) -> Result<()> {
1315    let guest_connection_ids = session
1316        .db()
1317        .await
1318        .update_worktree(&request, session.connection_id)
1319        .await?;
1320
1321    broadcast(
1322        session.connection_id,
1323        guest_connection_ids.iter().copied(),
1324        |connection_id| {
1325            session
1326                .peer
1327                .forward_send(session.connection_id, connection_id, request.clone())
1328        },
1329    );
1330    response.send(proto::Ack {})?;
1331    Ok(())
1332}
1333
1334async fn update_diagnostic_summary(
1335    message: proto::UpdateDiagnosticSummary,
1336    session: Session,
1337) -> Result<()> {
1338    let guest_connection_ids = session
1339        .db()
1340        .await
1341        .update_diagnostic_summary(&message, session.connection_id)
1342        .await?;
1343
1344    broadcast(
1345        session.connection_id,
1346        guest_connection_ids.iter().copied(),
1347        |connection_id| {
1348            session
1349                .peer
1350                .forward_send(session.connection_id, connection_id, message.clone())
1351        },
1352    );
1353
1354    Ok(())
1355}
1356
1357async fn start_language_server(
1358    request: proto::StartLanguageServer,
1359    session: Session,
1360) -> Result<()> {
1361    let guest_connection_ids = session
1362        .db()
1363        .await
1364        .start_language_server(&request, session.connection_id)
1365        .await?;
1366
1367    broadcast(
1368        session.connection_id,
1369        guest_connection_ids.iter().copied(),
1370        |connection_id| {
1371            session
1372                .peer
1373                .forward_send(session.connection_id, connection_id, request.clone())
1374        },
1375    );
1376    Ok(())
1377}
1378
1379async fn update_language_server(
1380    request: proto::UpdateLanguageServer,
1381    session: Session,
1382) -> Result<()> {
1383    let project_id = ProjectId::from_proto(request.project_id);
1384    let project_connection_ids = session
1385        .db()
1386        .await
1387        .project_connection_ids(project_id, session.connection_id)
1388        .await?;
1389    broadcast(
1390        session.connection_id,
1391        project_connection_ids.iter().copied(),
1392        |connection_id| {
1393            session
1394                .peer
1395                .forward_send(session.connection_id, connection_id, request.clone())
1396        },
1397    );
1398    Ok(())
1399}
1400
1401async fn forward_project_request<T>(
1402    request: T,
1403    response: Response<T>,
1404    session: Session,
1405) -> Result<()>
1406where
1407    T: EntityMessage + RequestMessage,
1408{
1409    let project_id = ProjectId::from_proto(request.remote_entity_id());
1410    let host_connection_id = {
1411        let collaborators = session
1412            .db()
1413            .await
1414            .project_collaborators(project_id, session.connection_id)
1415            .await?;
1416        let host = collaborators
1417            .iter()
1418            .find(|collaborator| collaborator.is_host)
1419            .ok_or_else(|| anyhow!("host not found"))?;
1420        ConnectionId {
1421            owner_id: host.connection_server_id.0 as u32,
1422            id: host.connection_id as u32,
1423        }
1424    };
1425
1426    let payload = session
1427        .peer
1428        .forward_request(session.connection_id, host_connection_id, request)
1429        .await?;
1430
1431    response.send(payload)?;
1432    Ok(())
1433}
1434
1435async fn save_buffer(
1436    request: proto::SaveBuffer,
1437    response: Response<proto::SaveBuffer>,
1438    session: Session,
1439) -> Result<()> {
1440    let project_id = ProjectId::from_proto(request.project_id);
1441    let host_connection_id = {
1442        let collaborators = session
1443            .db()
1444            .await
1445            .project_collaborators(project_id, session.connection_id)
1446            .await?;
1447        let host = collaborators
1448            .iter()
1449            .find(|collaborator| collaborator.is_host)
1450            .ok_or_else(|| anyhow!("host not found"))?;
1451        ConnectionId {
1452            owner_id: host.connection_server_id.0 as u32,
1453            id: host.connection_id as u32,
1454        }
1455    };
1456    let response_payload = session
1457        .peer
1458        .forward_request(session.connection_id, host_connection_id, request.clone())
1459        .await?;
1460
1461    let mut collaborators = session
1462        .db()
1463        .await
1464        .project_collaborators(project_id, session.connection_id)
1465        .await?;
1466    collaborators.retain(|collaborator| {
1467        let collaborator_connection = ConnectionId {
1468            owner_id: collaborator.connection_server_id.0 as u32,
1469            id: collaborator.connection_id as u32,
1470        };
1471        collaborator_connection != session.connection_id
1472    });
1473    let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
1474        owner_id: collaborator.connection_server_id.0 as u32,
1475        id: collaborator.connection_id as u32,
1476    });
1477    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1478        session
1479            .peer
1480            .forward_send(host_connection_id, conn_id, response_payload.clone())
1481    });
1482    response.send(response_payload)?;
1483    Ok(())
1484}
1485
1486async fn create_buffer_for_peer(
1487    request: proto::CreateBufferForPeer,
1488    session: Session,
1489) -> Result<()> {
1490    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1491    session
1492        .peer
1493        .forward_send(session.connection_id, peer_id.into(), request)?;
1494    Ok(())
1495}
1496
1497async fn update_buffer(
1498    request: proto::UpdateBuffer,
1499    response: Response<proto::UpdateBuffer>,
1500    session: Session,
1501) -> Result<()> {
1502    let project_id = ProjectId::from_proto(request.project_id);
1503    let project_connection_ids = session
1504        .db()
1505        .await
1506        .project_connection_ids(project_id, session.connection_id)
1507        .await?;
1508
1509    broadcast(
1510        session.connection_id,
1511        project_connection_ids.iter().copied(),
1512        |connection_id| {
1513            session
1514                .peer
1515                .forward_send(session.connection_id, connection_id, request.clone())
1516        },
1517    );
1518    response.send(proto::Ack {})?;
1519    Ok(())
1520}
1521
1522async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1523    let project_id = ProjectId::from_proto(request.project_id);
1524    let project_connection_ids = session
1525        .db()
1526        .await
1527        .project_connection_ids(project_id, session.connection_id)
1528        .await?;
1529
1530    broadcast(
1531        session.connection_id,
1532        project_connection_ids.iter().copied(),
1533        |connection_id| {
1534            session
1535                .peer
1536                .forward_send(session.connection_id, connection_id, request.clone())
1537        },
1538    );
1539    Ok(())
1540}
1541
1542async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1543    let project_id = ProjectId::from_proto(request.project_id);
1544    let project_connection_ids = session
1545        .db()
1546        .await
1547        .project_connection_ids(project_id, session.connection_id)
1548        .await?;
1549    broadcast(
1550        session.connection_id,
1551        project_connection_ids.iter().copied(),
1552        |connection_id| {
1553            session
1554                .peer
1555                .forward_send(session.connection_id, connection_id, request.clone())
1556        },
1557    );
1558    Ok(())
1559}
1560
1561async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1562    let project_id = ProjectId::from_proto(request.project_id);
1563    let project_connection_ids = session
1564        .db()
1565        .await
1566        .project_connection_ids(project_id, session.connection_id)
1567        .await?;
1568    broadcast(
1569        session.connection_id,
1570        project_connection_ids.iter().copied(),
1571        |connection_id| {
1572            session
1573                .peer
1574                .forward_send(session.connection_id, connection_id, request.clone())
1575        },
1576    );
1577    Ok(())
1578}
1579
1580async fn follow(
1581    request: proto::Follow,
1582    response: Response<proto::Follow>,
1583    session: Session,
1584) -> Result<()> {
1585    let project_id = ProjectId::from_proto(request.project_id);
1586    let leader_id = request
1587        .leader_id
1588        .ok_or_else(|| anyhow!("invalid leader id"))?
1589        .into();
1590    let follower_id = session.connection_id;
1591    {
1592        let project_connection_ids = session
1593            .db()
1594            .await
1595            .project_connection_ids(project_id, session.connection_id)
1596            .await?;
1597
1598        if !project_connection_ids.contains(&leader_id) {
1599            Err(anyhow!("no such peer"))?;
1600        }
1601    }
1602
1603    let mut response_payload = session
1604        .peer
1605        .forward_request(session.connection_id, leader_id, request)
1606        .await?;
1607    response_payload
1608        .views
1609        .retain(|view| view.leader_id != Some(follower_id.into()));
1610    response.send(response_payload)?;
1611    Ok(())
1612}
1613
1614async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1615    let project_id = ProjectId::from_proto(request.project_id);
1616    let leader_id = request
1617        .leader_id
1618        .ok_or_else(|| anyhow!("invalid leader id"))?
1619        .into();
1620    let project_connection_ids = session
1621        .db()
1622        .await
1623        .project_connection_ids(project_id, session.connection_id)
1624        .await?;
1625    if !project_connection_ids.contains(&leader_id) {
1626        Err(anyhow!("no such peer"))?;
1627    }
1628    session
1629        .peer
1630        .forward_send(session.connection_id, leader_id, request)?;
1631    Ok(())
1632}
1633
1634async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1635    let project_id = ProjectId::from_proto(request.project_id);
1636    let project_connection_ids = session
1637        .db
1638        .lock()
1639        .await
1640        .project_connection_ids(project_id, session.connection_id)
1641        .await?;
1642
1643    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1644        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1645        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1646        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1647    });
1648    for follower_peer_id in request.follower_ids.iter().copied() {
1649        let follower_connection_id = follower_peer_id.into();
1650        if project_connection_ids.contains(&follower_connection_id)
1651            && Some(follower_peer_id) != leader_id
1652        {
1653            session.peer.forward_send(
1654                session.connection_id,
1655                follower_connection_id,
1656                request.clone(),
1657            )?;
1658        }
1659    }
1660    Ok(())
1661}
1662
1663async fn get_users(
1664    request: proto::GetUsers,
1665    response: Response<proto::GetUsers>,
1666    session: Session,
1667) -> Result<()> {
1668    let user_ids = request
1669        .user_ids
1670        .into_iter()
1671        .map(UserId::from_proto)
1672        .collect();
1673    let users = session
1674        .db()
1675        .await
1676        .get_users_by_ids(user_ids)
1677        .await?
1678        .into_iter()
1679        .map(|user| proto::User {
1680            id: user.id.to_proto(),
1681            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1682            github_login: user.github_login,
1683        })
1684        .collect();
1685    response.send(proto::UsersResponse { users })?;
1686    Ok(())
1687}
1688
1689async fn fuzzy_search_users(
1690    request: proto::FuzzySearchUsers,
1691    response: Response<proto::FuzzySearchUsers>,
1692    session: Session,
1693) -> Result<()> {
1694    let query = request.query;
1695    let users = match query.len() {
1696        0 => vec![],
1697        1 | 2 => session
1698            .db()
1699            .await
1700            .get_user_by_github_account(&query, None)
1701            .await?
1702            .into_iter()
1703            .collect(),
1704        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1705    };
1706    let users = users
1707        .into_iter()
1708        .filter(|user| user.id != session.user_id)
1709        .map(|user| proto::User {
1710            id: user.id.to_proto(),
1711            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1712            github_login: user.github_login,
1713        })
1714        .collect();
1715    response.send(proto::UsersResponse { users })?;
1716    Ok(())
1717}
1718
1719async fn request_contact(
1720    request: proto::RequestContact,
1721    response: Response<proto::RequestContact>,
1722    session: Session,
1723) -> Result<()> {
1724    let requester_id = session.user_id;
1725    let responder_id = UserId::from_proto(request.responder_id);
1726    if requester_id == responder_id {
1727        return Err(anyhow!("cannot add yourself as a contact"))?;
1728    }
1729
1730    session
1731        .db()
1732        .await
1733        .send_contact_request(requester_id, responder_id)
1734        .await?;
1735
1736    // Update outgoing contact requests of requester
1737    let mut update = proto::UpdateContacts::default();
1738    update.outgoing_requests.push(responder_id.to_proto());
1739    for connection_id in session
1740        .connection_pool()
1741        .await
1742        .user_connection_ids(requester_id)
1743    {
1744        session.peer.send(connection_id, update.clone())?;
1745    }
1746
1747    // Update incoming contact requests of responder
1748    let mut update = proto::UpdateContacts::default();
1749    update
1750        .incoming_requests
1751        .push(proto::IncomingContactRequest {
1752            requester_id: requester_id.to_proto(),
1753            should_notify: true,
1754        });
1755    for connection_id in session
1756        .connection_pool()
1757        .await
1758        .user_connection_ids(responder_id)
1759    {
1760        session.peer.send(connection_id, update.clone())?;
1761    }
1762
1763    response.send(proto::Ack {})?;
1764    Ok(())
1765}
1766
1767async fn respond_to_contact_request(
1768    request: proto::RespondToContactRequest,
1769    response: Response<proto::RespondToContactRequest>,
1770    session: Session,
1771) -> Result<()> {
1772    let responder_id = session.user_id;
1773    let requester_id = UserId::from_proto(request.requester_id);
1774    let db = session.db().await;
1775    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1776        db.dismiss_contact_notification(responder_id, requester_id)
1777            .await?;
1778    } else {
1779        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1780
1781        db.respond_to_contact_request(responder_id, requester_id, accept)
1782            .await?;
1783        let requester_busy = db.is_user_busy(requester_id).await?;
1784        let responder_busy = db.is_user_busy(responder_id).await?;
1785
1786        let pool = session.connection_pool().await;
1787        // Update responder with new contact
1788        let mut update = proto::UpdateContacts::default();
1789        if accept {
1790            update
1791                .contacts
1792                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1793        }
1794        update
1795            .remove_incoming_requests
1796            .push(requester_id.to_proto());
1797        for connection_id in pool.user_connection_ids(responder_id) {
1798            session.peer.send(connection_id, update.clone())?;
1799        }
1800
1801        // Update requester with new contact
1802        let mut update = proto::UpdateContacts::default();
1803        if accept {
1804            update
1805                .contacts
1806                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1807        }
1808        update
1809            .remove_outgoing_requests
1810            .push(responder_id.to_proto());
1811        for connection_id in pool.user_connection_ids(requester_id) {
1812            session.peer.send(connection_id, update.clone())?;
1813        }
1814    }
1815
1816    response.send(proto::Ack {})?;
1817    Ok(())
1818}
1819
1820async fn remove_contact(
1821    request: proto::RemoveContact,
1822    response: Response<proto::RemoveContact>,
1823    session: Session,
1824) -> Result<()> {
1825    let requester_id = session.user_id;
1826    let responder_id = UserId::from_proto(request.user_id);
1827    let db = session.db().await;
1828    db.remove_contact(requester_id, responder_id).await?;
1829
1830    let pool = session.connection_pool().await;
1831    // Update outgoing contact requests of requester
1832    let mut update = proto::UpdateContacts::default();
1833    update
1834        .remove_outgoing_requests
1835        .push(responder_id.to_proto());
1836    for connection_id in pool.user_connection_ids(requester_id) {
1837        session.peer.send(connection_id, update.clone())?;
1838    }
1839
1840    // Update incoming contact requests of responder
1841    let mut update = proto::UpdateContacts::default();
1842    update
1843        .remove_incoming_requests
1844        .push(requester_id.to_proto());
1845    for connection_id in pool.user_connection_ids(responder_id) {
1846        session.peer.send(connection_id, update.clone())?;
1847    }
1848
1849    response.send(proto::Ack {})?;
1850    Ok(())
1851}
1852
1853async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1854    let project_id = ProjectId::from_proto(request.project_id);
1855    let project_connection_ids = session
1856        .db()
1857        .await
1858        .project_connection_ids(project_id, session.connection_id)
1859        .await?;
1860    broadcast(
1861        session.connection_id,
1862        project_connection_ids.iter().copied(),
1863        |connection_id| {
1864            session
1865                .peer
1866                .forward_send(session.connection_id, connection_id, request.clone())
1867        },
1868    );
1869    Ok(())
1870}
1871
1872async fn get_private_user_info(
1873    _request: proto::GetPrivateUserInfo,
1874    response: Response<proto::GetPrivateUserInfo>,
1875    session: Session,
1876) -> Result<()> {
1877    let metrics_id = session
1878        .db()
1879        .await
1880        .get_user_metrics_id(session.user_id)
1881        .await?;
1882    let user = session
1883        .db()
1884        .await
1885        .get_user_by_id(session.user_id)
1886        .await?
1887        .ok_or_else(|| anyhow!("user not found"))?;
1888    response.send(proto::GetPrivateUserInfoResponse {
1889        metrics_id,
1890        staff: user.admin,
1891    })?;
1892    Ok(())
1893}
1894
1895fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1896    match message {
1897        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1898        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1899        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1900        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1901        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1902            code: frame.code.into(),
1903            reason: frame.reason,
1904        })),
1905    }
1906}
1907
1908fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1909    match message {
1910        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1911        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1912        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1913        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1914        AxumMessage::Close(frame) => {
1915            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1916                code: frame.code.into(),
1917                reason: frame.reason,
1918            }))
1919        }
1920    }
1921}
1922
1923fn build_initial_contacts_update(
1924    contacts: Vec<db::Contact>,
1925    pool: &ConnectionPool,
1926) -> proto::UpdateContacts {
1927    let mut update = proto::UpdateContacts::default();
1928
1929    for contact in contacts {
1930        match contact {
1931            db::Contact::Accepted {
1932                user_id,
1933                should_notify,
1934                busy,
1935            } => {
1936                update
1937                    .contacts
1938                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1939            }
1940            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1941            db::Contact::Incoming {
1942                user_id,
1943                should_notify,
1944            } => update
1945                .incoming_requests
1946                .push(proto::IncomingContactRequest {
1947                    requester_id: user_id.to_proto(),
1948                    should_notify,
1949                }),
1950        }
1951    }
1952
1953    update
1954}
1955
1956fn contact_for_user(
1957    user_id: UserId,
1958    should_notify: bool,
1959    busy: bool,
1960    pool: &ConnectionPool,
1961) -> proto::Contact {
1962    proto::Contact {
1963        user_id: user_id.to_proto(),
1964        online: pool.is_user_online(user_id),
1965        busy,
1966        should_notify,
1967    }
1968}
1969
1970fn room_updated(room: &proto::Room, peer: &Peer) {
1971    for participant in &room.participants {
1972        if let Some(peer_id) = participant
1973            .peer_id
1974            .ok_or_else(|| anyhow!("invalid participant peer id"))
1975            .trace_err()
1976        {
1977            peer.send(
1978                peer_id.into(),
1979                proto::RoomUpdated {
1980                    room: Some(room.clone()),
1981                },
1982            )
1983            .trace_err();
1984        }
1985    }
1986}
1987
1988async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1989    let db = session.db().await;
1990    let contacts = db.get_contacts(user_id).await?;
1991    let busy = db.is_user_busy(user_id).await?;
1992
1993    let pool = session.connection_pool().await;
1994    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1995    for contact in contacts {
1996        if let db::Contact::Accepted {
1997            user_id: contact_user_id,
1998            ..
1999        } = contact
2000        {
2001            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2002                session
2003                    .peer
2004                    .send(
2005                        contact_conn_id,
2006                        proto::UpdateContacts {
2007                            contacts: vec![updated_contact.clone()],
2008                            remove_contacts: Default::default(),
2009                            incoming_requests: Default::default(),
2010                            remove_incoming_requests: Default::default(),
2011                            outgoing_requests: Default::default(),
2012                            remove_outgoing_requests: Default::default(),
2013                        },
2014                    )
2015                    .trace_err();
2016            }
2017        }
2018    }
2019    Ok(())
2020}
2021
2022async fn leave_room_for_session(session: &Session) -> Result<()> {
2023    let mut contacts_to_update = HashSet::default();
2024
2025    let room_id;
2026    let canceled_calls_to_user_ids;
2027    let live_kit_room;
2028    let delete_live_kit_room;
2029    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2030        contacts_to_update.insert(session.user_id);
2031
2032        for project in left_room.left_projects.values() {
2033            project_left(project, session);
2034        }
2035
2036        room_updated(&left_room.room, &session.peer);
2037        room_id = RoomId::from_proto(left_room.room.id);
2038        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2039        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2040        delete_live_kit_room = left_room.room.participants.is_empty();
2041    } else {
2042        return Ok(());
2043    }
2044
2045    {
2046        let pool = session.connection_pool().await;
2047        for canceled_user_id in canceled_calls_to_user_ids {
2048            for connection_id in pool.user_connection_ids(canceled_user_id) {
2049                session
2050                    .peer
2051                    .send(
2052                        connection_id,
2053                        proto::CallCanceled {
2054                            room_id: room_id.to_proto(),
2055                        },
2056                    )
2057                    .trace_err();
2058            }
2059            contacts_to_update.insert(canceled_user_id);
2060        }
2061    }
2062
2063    for contact_user_id in contacts_to_update {
2064        update_user_contacts(contact_user_id, &session).await?;
2065    }
2066
2067    if let Some(live_kit) = session.live_kit_client.as_ref() {
2068        live_kit
2069            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2070            .await
2071            .trace_err();
2072
2073        if delete_live_kit_room {
2074            live_kit.delete_room(live_kit_room).await.trace_err();
2075        }
2076    }
2077
2078    Ok(())
2079}
2080
2081fn project_left(project: &db::LeftProject, session: &Session) {
2082    for connection_id in &project.connection_ids {
2083        if project.host_user_id == session.user_id {
2084            session
2085                .peer
2086                .send(
2087                    *connection_id,
2088                    proto::UnshareProject {
2089                        project_id: project.id.to_proto(),
2090                    },
2091                )
2092                .trace_err();
2093        } else {
2094            session
2095                .peer
2096                .send(
2097                    *connection_id,
2098                    proto::RemoveProjectCollaborator {
2099                        project_id: project.id.to_proto(),
2100                        peer_id: Some(session.connection_id.into()),
2101                    },
2102                )
2103                .trace_err();
2104        }
2105    }
2106
2107    session
2108        .peer
2109        .send(
2110            session.connection_id,
2111            proto::UnshareProject {
2112                project_id: project.id.to_proto(),
2113            },
2114        )
2115        .trace_err();
2116}
2117
2118pub trait ResultExt {
2119    type Ok;
2120
2121    fn trace_err(self) -> Option<Self::Ok>;
2122}
2123
2124impl<T, E> ResultExt for Result<T, E>
2125where
2126    E: std::fmt::Debug,
2127{
2128    type Ok = T;
2129
2130    fn trace_err(self) -> Option<T> {
2131        match self {
2132            Ok(value) => Some(value),
2133            Err(error) => {
2134                tracing::error!("{:?}", error);
2135                None
2136            }
2137        }
2138    }
2139}