rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98}
  99
 100impl Session {
 101    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 102        #[cfg(test)]
 103        tokio::task::yield_now().await;
 104        let guard = self.db.lock().await;
 105        #[cfg(test)]
 106        tokio::task::yield_now().await;
 107        guard
 108    }
 109
 110    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 111        #[cfg(test)]
 112        tokio::task::yield_now().await;
 113        let guard = self.connection_pool.lock();
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<Database>);
 131
 132impl Deref for DbHandle {
 133    type Target = Database;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    id: parking_lot::Mutex<ServerId>,
 142    peer: Arc<Peer>,
 143    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 144    app_state: Arc<AppState>,
 145    executor: Executor,
 146    handlers: HashMap<TypeId, MessageHandler>,
 147    teardown: watch::Sender<()>,
 148}
 149
 150pub(crate) struct ConnectionPoolGuard<'a> {
 151    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 152    _not_send: PhantomData<Rc<()>>,
 153}
 154
 155#[derive(Serialize)]
 156pub struct ServerSnapshot<'a> {
 157    peer: &'a Peer,
 158    #[serde(serialize_with = "serialize_deref")]
 159    connection_pool: ConnectionPoolGuard<'a>,
 160}
 161
 162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 163where
 164    S: Serializer,
 165    T: Deref<Target = U>,
 166    U: Serialize,
 167{
 168    Serialize::serialize(value.deref(), serializer)
 169}
 170
 171impl Server {
 172    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 173        let mut server = Self {
 174            id: parking_lot::Mutex::new(id),
 175            peer: Peer::new(id.0 as u32),
 176            app_state,
 177            executor,
 178            connection_pool: Default::default(),
 179            handlers: Default::default(),
 180            teardown: watch::channel(()).0,
 181        };
 182
 183        server
 184            .add_request_handler(ping)
 185            .add_request_handler(create_room)
 186            .add_request_handler(join_room)
 187            .add_message_handler(leave_room)
 188            .add_request_handler(call)
 189            .add_request_handler(cancel_call)
 190            .add_message_handler(decline_call)
 191            .add_request_handler(update_participant_location)
 192            .add_request_handler(share_project)
 193            .add_message_handler(unshare_project)
 194            .add_request_handler(join_project)
 195            .add_message_handler(leave_project)
 196            .add_request_handler(update_project)
 197            .add_request_handler(update_worktree)
 198            .add_message_handler(start_language_server)
 199            .add_message_handler(update_language_server)
 200            .add_message_handler(update_diagnostic_summary)
 201            .add_request_handler(forward_project_request::<proto::GetHover>)
 202            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 203            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 204            .add_request_handler(forward_project_request::<proto::GetReferences>)
 205            .add_request_handler(forward_project_request::<proto::SearchProject>)
 206            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 207            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 208            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 209            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 210            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 211            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 212            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 213            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 214            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 215            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 216            .add_request_handler(forward_project_request::<proto::PerformRename>)
 217            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 218            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 219            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 220            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 221            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 222            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 223            .add_message_handler(create_buffer_for_peer)
 224            .add_request_handler(update_buffer)
 225            .add_message_handler(update_buffer_file)
 226            .add_message_handler(buffer_reloaded)
 227            .add_message_handler(buffer_saved)
 228            .add_request_handler(save_buffer)
 229            .add_request_handler(get_users)
 230            .add_request_handler(fuzzy_search_users)
 231            .add_request_handler(request_contact)
 232            .add_request_handler(remove_contact)
 233            .add_request_handler(respond_to_contact_request)
 234            .add_request_handler(follow)
 235            .add_message_handler(unfollow)
 236            .add_message_handler(update_followers)
 237            .add_message_handler(update_diff_base)
 238            .add_request_handler(get_private_user_info);
 239
 240        Arc::new(server)
 241    }
 242
 243    pub async fn start(&self) -> Result<()> {
 244        let server_id = *self.id.lock();
 245        let app_state = self.app_state.clone();
 246        let peer = self.peer.clone();
 247        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 248        let pool = self.connection_pool.clone();
 249        let live_kit_client = self.app_state.live_kit_client.clone();
 250
 251        let span = info_span!("start server");
 252        let span_enter = span.enter();
 253
 254        tracing::info!("begin deleting stale projects");
 255        app_state
 256            .db
 257            .delete_stale_projects(&app_state.config.zed_environment, server_id)
 258            .await?;
 259        tracing::info!("finish deleting stale projects");
 260
 261        drop(span_enter);
 262        self.executor.spawn_detached(
 263            async move {
 264                tracing::info!("waiting for cleanup timeout");
 265                timeout.await;
 266                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 267                if let Some(room_ids) = app_state
 268                    .db
 269                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 270                    .await
 271                    .trace_err()
 272                {
 273                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 274                    for room_id in room_ids {
 275                        let mut contacts_to_update = HashSet::default();
 276                        let mut canceled_calls_to_user_ids = Vec::new();
 277                        let mut live_kit_room = String::new();
 278                        let mut delete_live_kit_room = false;
 279
 280                        if let Ok(mut refreshed_room) =
 281                            app_state.db.refresh_room(room_id, server_id).await
 282                        {
 283                            tracing::info!(
 284                                room_id = room_id.0,
 285                                new_participant_count = refreshed_room.room.participants.len(),
 286                                "refreshed room"
 287                            );
 288                            room_updated(&refreshed_room.room, &peer);
 289                            contacts_to_update
 290                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 291                            contacts_to_update
 292                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 293                            canceled_calls_to_user_ids =
 294                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 295                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 296                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 297                        }
 298
 299                        {
 300                            let pool = pool.lock();
 301                            for canceled_user_id in canceled_calls_to_user_ids {
 302                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 303                                    peer.send(
 304                                        connection_id,
 305                                        proto::CallCanceled {
 306                                            room_id: room_id.to_proto(),
 307                                        },
 308                                    )
 309                                    .trace_err();
 310                                }
 311                            }
 312                        }
 313
 314                        for user_id in contacts_to_update {
 315                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 316                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 317                            if let Some((busy, contacts)) = busy.zip(contacts) {
 318                                let pool = pool.lock();
 319                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 320                                for contact in contacts {
 321                                    if let db::Contact::Accepted {
 322                                        user_id: contact_user_id,
 323                                        ..
 324                                    } = contact
 325                                    {
 326                                        for contact_conn_id in
 327                                            pool.user_connection_ids(contact_user_id)
 328                                        {
 329                                            peer.send(
 330                                                contact_conn_id,
 331                                                proto::UpdateContacts {
 332                                                    contacts: vec![updated_contact.clone()],
 333                                                    remove_contacts: Default::default(),
 334                                                    incoming_requests: Default::default(),
 335                                                    remove_incoming_requests: Default::default(),
 336                                                    outgoing_requests: Default::default(),
 337                                                    remove_outgoing_requests: Default::default(),
 338                                                },
 339                                            )
 340                                            .trace_err();
 341                                        }
 342                                    }
 343                                }
 344                            }
 345                        }
 346
 347                        if let Some(live_kit) = live_kit_client.as_ref() {
 348                            if delete_live_kit_room {
 349                                live_kit.delete_room(live_kit_room).await.trace_err();
 350                            }
 351                        }
 352                    }
 353                }
 354
 355                app_state
 356                    .db
 357                    .delete_stale_servers(server_id, &app_state.config.zed_environment)
 358                    .await
 359                    .trace_err();
 360            }
 361            .instrument(span),
 362        );
 363        Ok(())
 364    }
 365
 366    pub fn teardown(&self) {
 367        self.peer.teardown();
 368        self.connection_pool.lock().reset();
 369        let _ = self.teardown.send(());
 370    }
 371
 372    #[cfg(test)]
 373    pub fn reset(&self, id: ServerId) {
 374        self.teardown();
 375        *self.id.lock() = id;
 376        self.peer.reset(id.0 as u32);
 377    }
 378
 379    #[cfg(test)]
 380    pub fn id(&self) -> ServerId {
 381        *self.id.lock()
 382    }
 383
 384    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 385    where
 386        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 387        Fut: 'static + Send + Future<Output = Result<()>>,
 388        M: EnvelopedMessage,
 389    {
 390        let prev_handler = self.handlers.insert(
 391            TypeId::of::<M>(),
 392            Box::new(move |envelope, session| {
 393                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 394                let span = info_span!(
 395                    "handle message",
 396                    payload_type = envelope.payload_type_name()
 397                );
 398                span.in_scope(|| {
 399                    tracing::info!(
 400                        payload_type = envelope.payload_type_name(),
 401                        "message received"
 402                    );
 403                });
 404                let future = (handler)(*envelope, session);
 405                async move {
 406                    if let Err(error) = future.await {
 407                        tracing::error!(%error, "error handling message");
 408                    }
 409                }
 410                .instrument(span)
 411                .boxed()
 412            }),
 413        );
 414        if prev_handler.is_some() {
 415            panic!("registered a handler for the same message twice");
 416        }
 417        self
 418    }
 419
 420    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 421    where
 422        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 423        Fut: 'static + Send + Future<Output = Result<()>>,
 424        M: EnvelopedMessage,
 425    {
 426        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 427        self
 428    }
 429
 430    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 431    where
 432        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 433        Fut: Send + Future<Output = Result<()>>,
 434        M: RequestMessage,
 435    {
 436        let handler = Arc::new(handler);
 437        self.add_handler(move |envelope, session| {
 438            let receipt = envelope.receipt();
 439            let handler = handler.clone();
 440            async move {
 441                let peer = session.peer.clone();
 442                let responded = Arc::new(AtomicBool::default());
 443                let response = Response {
 444                    peer: peer.clone(),
 445                    responded: responded.clone(),
 446                    receipt,
 447                };
 448                match (handler)(envelope.payload, response, session).await {
 449                    Ok(()) => {
 450                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 451                            Ok(())
 452                        } else {
 453                            Err(anyhow!("handler did not send a response"))?
 454                        }
 455                    }
 456                    Err(error) => {
 457                        peer.respond_with_error(
 458                            receipt,
 459                            proto::Error {
 460                                message: error.to_string(),
 461                            },
 462                        )?;
 463                        Err(error)
 464                    }
 465                }
 466            }
 467        })
 468    }
 469
 470    pub fn handle_connection(
 471        self: &Arc<Self>,
 472        connection: Connection,
 473        address: String,
 474        user: User,
 475        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 476        executor: Executor,
 477    ) -> impl Future<Output = Result<()>> {
 478        let this = self.clone();
 479        let user_id = user.id;
 480        let login = user.github_login;
 481        let span = info_span!("handle connection", %user_id, %login, %address);
 482        let mut teardown = self.teardown.subscribe();
 483        async move {
 484            let (connection_id, handle_io, mut incoming_rx) = this
 485                .peer
 486                .add_connection(connection, {
 487                    let executor = executor.clone();
 488                    move |duration| executor.sleep(duration)
 489                });
 490
 491            tracing::info!(%user_id, %login, ?connection_id, %address, "connection opened");
 492            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 493            tracing::info!(%user_id, %login, ?connection_id, %address, "sent hello message");
 494
 495            if let Some(send_connection_id) = send_connection_id.take() {
 496                let _ = send_connection_id.send(connection_id);
 497            }
 498
 499            if !user.connected_once {
 500                this.peer.send(connection_id, proto::ShowContacts {})?;
 501                this.app_state.db.set_user_connected_once(user_id, true).await?;
 502            }
 503
 504            let (contacts, invite_code) = future::try_join(
 505                this.app_state.db.get_contacts(user_id),
 506                this.app_state.db.get_invite_code_for_user(user_id)
 507            ).await?;
 508
 509            {
 510                let mut pool = this.connection_pool.lock();
 511                pool.add_connection(connection_id, user_id, user.admin);
 512                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 513
 514                if let Some((code, count)) = invite_code {
 515                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 516                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 517                        count: count as u32,
 518                    })?;
 519                }
 520            }
 521
 522            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 523                this.peer.send(connection_id, incoming_call)?;
 524            }
 525
 526            let session = Session {
 527                user_id,
 528                connection_id,
 529                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 530                peer: this.peer.clone(),
 531                connection_pool: this.connection_pool.clone(),
 532                live_kit_client: this.app_state.live_kit_client.clone()
 533            };
 534            update_user_contacts(user_id, &session).await?;
 535
 536            let handle_io = handle_io.fuse();
 537            futures::pin_mut!(handle_io);
 538
 539            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 540            // This prevents deadlocks when e.g., client A performs a request to client B and
 541            // client B performs a request to client A. If both clients stop processing further
 542            // messages until their respective request completes, they won't have a chance to
 543            // respond to the other client's request and cause a deadlock.
 544            //
 545            // This arrangement ensures we will attempt to process earlier messages first, but fall
 546            // back to processing messages arrived later in the spirit of making progress.
 547            let mut foreground_message_handlers = FuturesUnordered::new();
 548            loop {
 549                let next_message = incoming_rx.next().fuse();
 550                futures::pin_mut!(next_message);
 551                futures::select_biased! {
 552                    _ = teardown.changed().fuse() => return Ok(()),
 553                    result = handle_io => {
 554                        if let Err(error) = result {
 555                            tracing::error!(?error, %user_id, %login, ?connection_id, %address, "error handling I/O");
 556                        }
 557                        break;
 558                    }
 559                    _ = foreground_message_handlers.next() => {}
 560                    message = next_message => {
 561                        if let Some(message) = message {
 562                            let type_name = message.payload_type_name();
 563                            let span = tracing::info_span!("receive message", %user_id, %login, ?connection_id, %address, type_name);
 564                            let span_enter = span.enter();
 565                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 566                                let is_background = message.is_background();
 567                                let handle_message = (handler)(message, session.clone());
 568                                drop(span_enter);
 569
 570                                let handle_message = handle_message.instrument(span);
 571                                if is_background {
 572                                    executor.spawn_detached(handle_message);
 573                                } else {
 574                                    foreground_message_handlers.push(handle_message);
 575                                }
 576                            } else {
 577                                tracing::error!(%user_id, %login, ?connection_id, %address, "no message handler");
 578                            }
 579                        } else {
 580                            tracing::info!(%user_id, %login, ?connection_id, %address, "connection closed");
 581                            break;
 582                        }
 583                    }
 584                }
 585            }
 586
 587            drop(foreground_message_handlers);
 588            tracing::info!(%user_id, %login, ?connection_id, %address, "signing out");
 589            if let Err(error) = sign_out(session, teardown, executor).await {
 590                tracing::error!(%user_id, %login, ?connection_id, %address, ?error, "error signing out");
 591            }
 592
 593            Ok(())
 594        }.instrument(span)
 595    }
 596
 597    pub async fn invite_code_redeemed(
 598        self: &Arc<Self>,
 599        inviter_id: UserId,
 600        invitee_id: UserId,
 601    ) -> Result<()> {
 602        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 603            if let Some(code) = &user.invite_code {
 604                let pool = self.connection_pool.lock();
 605                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 606                for connection_id in pool.user_connection_ids(inviter_id) {
 607                    self.peer.send(
 608                        connection_id,
 609                        proto::UpdateContacts {
 610                            contacts: vec![invitee_contact.clone()],
 611                            ..Default::default()
 612                        },
 613                    )?;
 614                    self.peer.send(
 615                        connection_id,
 616                        proto::UpdateInviteInfo {
 617                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 618                            count: user.invite_count as u32,
 619                        },
 620                    )?;
 621                }
 622            }
 623        }
 624        Ok(())
 625    }
 626
 627    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 628        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 629            if let Some(invite_code) = &user.invite_code {
 630                let pool = self.connection_pool.lock();
 631                for connection_id in pool.user_connection_ids(user_id) {
 632                    self.peer.send(
 633                        connection_id,
 634                        proto::UpdateInviteInfo {
 635                            url: format!(
 636                                "{}{}",
 637                                self.app_state.config.invite_link_prefix, invite_code
 638                            ),
 639                            count: user.invite_count as u32,
 640                        },
 641                    )?;
 642                }
 643            }
 644        }
 645        Ok(())
 646    }
 647
 648    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 649        ServerSnapshot {
 650            connection_pool: ConnectionPoolGuard {
 651                guard: self.connection_pool.lock(),
 652                _not_send: PhantomData,
 653            },
 654            peer: &self.peer,
 655        }
 656    }
 657}
 658
 659impl<'a> Deref for ConnectionPoolGuard<'a> {
 660    type Target = ConnectionPool;
 661
 662    fn deref(&self) -> &Self::Target {
 663        &*self.guard
 664    }
 665}
 666
 667impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 668    fn deref_mut(&mut self) -> &mut Self::Target {
 669        &mut *self.guard
 670    }
 671}
 672
 673impl<'a> Drop for ConnectionPoolGuard<'a> {
 674    fn drop(&mut self) {
 675        #[cfg(test)]
 676        self.check_invariants();
 677    }
 678}
 679
 680fn broadcast<F>(
 681    sender_id: ConnectionId,
 682    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 683    mut f: F,
 684) where
 685    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 686{
 687    for receiver_id in receiver_ids {
 688        if receiver_id != sender_id {
 689            f(receiver_id).trace_err();
 690        }
 691    }
 692}
 693
 694lazy_static! {
 695    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 696}
 697
 698pub struct ProtocolVersion(u32);
 699
 700impl Header for ProtocolVersion {
 701    fn name() -> &'static HeaderName {
 702        &ZED_PROTOCOL_VERSION
 703    }
 704
 705    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 706    where
 707        Self: Sized,
 708        I: Iterator<Item = &'i axum::http::HeaderValue>,
 709    {
 710        let version = values
 711            .next()
 712            .ok_or_else(axum::headers::Error::invalid)?
 713            .to_str()
 714            .map_err(|_| axum::headers::Error::invalid())?
 715            .parse()
 716            .map_err(|_| axum::headers::Error::invalid())?;
 717        Ok(Self(version))
 718    }
 719
 720    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 721        values.extend([self.0.to_string().parse().unwrap()]);
 722    }
 723}
 724
 725pub fn routes(server: Arc<Server>) -> Router<Body> {
 726    Router::new()
 727        .route("/rpc", get(handle_websocket_request))
 728        .layer(
 729            ServiceBuilder::new()
 730                .layer(Extension(server.app_state.clone()))
 731                .layer(middleware::from_fn(auth::validate_header)),
 732        )
 733        .route("/metrics", get(handle_metrics))
 734        .layer(Extension(server))
 735}
 736
 737pub async fn handle_websocket_request(
 738    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 739    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 740    Extension(server): Extension<Arc<Server>>,
 741    Extension(user): Extension<User>,
 742    ws: WebSocketUpgrade,
 743) -> axum::response::Response {
 744    if protocol_version != rpc::PROTOCOL_VERSION {
 745        return (
 746            StatusCode::UPGRADE_REQUIRED,
 747            "client must be upgraded".to_string(),
 748        )
 749            .into_response();
 750    }
 751    let socket_address = socket_address.to_string();
 752    ws.on_upgrade(move |socket| {
 753        use util::ResultExt;
 754        let socket = socket
 755            .map_ok(to_tungstenite_message)
 756            .err_into()
 757            .with(|message| async move { Ok(to_axum_message(message)) });
 758        let connection = Connection::new(Box::pin(socket));
 759        async move {
 760            server
 761                .handle_connection(connection, socket_address, user, None, Executor::Production)
 762                .await
 763                .log_err();
 764        }
 765    })
 766}
 767
 768pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 769    let connections = server
 770        .connection_pool
 771        .lock()
 772        .connections()
 773        .filter(|connection| !connection.admin)
 774        .count();
 775
 776    METRIC_CONNECTIONS.set(connections as _);
 777
 778    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 779    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 780
 781    let encoder = prometheus::TextEncoder::new();
 782    let metric_families = prometheus::gather();
 783    let encoded_metrics = encoder
 784        .encode_to_string(&metric_families)
 785        .map_err(|err| anyhow!("{}", err))?;
 786    Ok(encoded_metrics)
 787}
 788
/// Cleans up after a session's connection closes.
///
/// The first steps run immediately:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - report every project the connection left (via `project_left`).
///
/// It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: leave the session's room, decline any pending
///   call if the user has no other online connection, and refresh the user's
///   contacts;
/// - `teardown` changes: skip that delayed cleanup.
///
/// Errors from `remove_connection` and `update_user_contacts` are returned.
/// Errors from `connection_lost`, leaving the room, and declining the call are
/// only traced.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A failure here is traced rather than propagated so the remaining teardown still runs.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        // Move the projects out of the returned value, leaving an empty default behind.
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls branches in order, so the timeout branch wins if both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 834
 835async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 836    response.send(proto::Ack {})?;
 837    Ok(())
 838}
 839
 840async fn create_room(
 841    _request: proto::CreateRoom,
 842    response: Response<proto::CreateRoom>,
 843    session: Session,
 844) -> Result<()> {
 845    let live_kit_room = nanoid::nanoid!(30);
 846    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 847        if let Some(_) = live_kit
 848            .create_room(live_kit_room.clone())
 849            .await
 850            .trace_err()
 851        {
 852            if let Some(token) = live_kit
 853                .room_token(&live_kit_room, &session.connection_id.to_string())
 854                .trace_err()
 855            {
 856                Some(proto::LiveKitConnectionInfo {
 857                    server_url: live_kit.url().into(),
 858                    token,
 859                })
 860            } else {
 861                None
 862            }
 863        } else {
 864            None
 865        }
 866    } else {
 867        None
 868    };
 869
 870    {
 871        let room = session
 872            .db()
 873            .await
 874            .create_room(session.user_id, session.connection_id, &live_kit_room)
 875            .await?;
 876
 877        response.send(proto::CreateRoomResponse {
 878            room: Some(room.clone()),
 879            live_kit_connection_info,
 880        })?;
 881    }
 882
 883    update_user_contacts(session.user_id, &session).await?;
 884    Ok(())
 885}
 886
 887async fn join_room(
 888    request: proto::JoinRoom,
 889    response: Response<proto::JoinRoom>,
 890    session: Session,
 891) -> Result<()> {
 892    let room_id = RoomId::from_proto(request.id);
 893    let room = {
 894        let room = session
 895            .db()
 896            .await
 897            .join_room(room_id, session.user_id, session.connection_id)
 898            .await?;
 899        room_updated(&room, &session.peer);
 900        room.clone()
 901    };
 902
 903    for connection_id in session
 904        .connection_pool()
 905        .await
 906        .user_connection_ids(session.user_id)
 907    {
 908        session
 909            .peer
 910            .send(
 911                connection_id,
 912                proto::CallCanceled {
 913                    room_id: room_id.to_proto(),
 914                },
 915            )
 916            .trace_err();
 917    }
 918
 919    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 920        if let Some(token) = live_kit
 921            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 922            .trace_err()
 923        {
 924            Some(proto::LiveKitConnectionInfo {
 925                server_url: live_kit.url().into(),
 926                token,
 927            })
 928        } else {
 929            None
 930        }
 931    } else {
 932        None
 933    };
 934
 935    response.send(proto::JoinRoomResponse {
 936        room: Some(room),
 937        live_kit_connection_info,
 938    })?;
 939
 940    update_user_contacts(session.user_id, &session).await?;
 941    Ok(())
 942}
 943
 944async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 945    leave_room_for_session(&session).await
 946}
 947
/// Rings `request.called_user_id`, inviting them into `request.room_id`.
///
/// The caller must have the callee as a contact. The call is recorded in the
/// database, the room update is broadcast, and the callee's contacts are
/// refreshed. Then every connection of the callee is sent the incoming call as
/// a request, and the first successful reply acks the caller's request.
///
/// If no connection replies successfully (including when the callee has no
/// connections), the handler:
/// - marks the call failed;
/// - broadcasts the room again;
/// - refreshes the callee's contacts;
/// - returns "failed to ring user".
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` is dropped before contacts are updated;
    // the incoming-call message is moved out so it can be used afterwards.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Replies arrive in completion order; the first success ends the handler.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and report it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1015
1016async fn cancel_call(
1017    request: proto::CancelCall,
1018    response: Response<proto::CancelCall>,
1019    session: Session,
1020) -> Result<()> {
1021    let called_user_id = UserId::from_proto(request.called_user_id);
1022    let room_id = RoomId::from_proto(request.room_id);
1023    {
1024        let room = session
1025            .db()
1026            .await
1027            .cancel_call(room_id, session.connection_id, called_user_id)
1028            .await?;
1029        room_updated(&room, &session.peer);
1030    }
1031
1032    for connection_id in session
1033        .connection_pool()
1034        .await
1035        .user_connection_ids(called_user_id)
1036    {
1037        session
1038            .peer
1039            .send(
1040                connection_id,
1041                proto::CallCanceled {
1042                    room_id: room_id.to_proto(),
1043                },
1044            )
1045            .trace_err();
1046    }
1047    response.send(proto::Ack {})?;
1048
1049    update_user_contacts(called_user_id, &session).await?;
1050    Ok(())
1051}
1052
1053async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1054    let room_id = RoomId::from_proto(message.room_id);
1055    {
1056        let room = session
1057            .db()
1058            .await
1059            .decline_call(Some(room_id), session.user_id)
1060            .await?
1061            .ok_or_else(|| anyhow!("failed to decline call"))?;
1062        room_updated(&room, &session.peer);
1063    }
1064
1065    for connection_id in session
1066        .connection_pool()
1067        .await
1068        .user_connection_ids(session.user_id)
1069    {
1070        session
1071            .peer
1072            .send(
1073                connection_id,
1074                proto::CallCanceled {
1075                    room_id: room_id.to_proto(),
1076                },
1077            )
1078            .trace_err();
1079    }
1080    update_user_contacts(session.user_id, &session).await?;
1081    Ok(())
1082}
1083
1084async fn update_participant_location(
1085    request: proto::UpdateParticipantLocation,
1086    response: Response<proto::UpdateParticipantLocation>,
1087    session: Session,
1088) -> Result<()> {
1089    let room_id = RoomId::from_proto(request.room_id);
1090    let location = request
1091        .location
1092        .ok_or_else(|| anyhow!("invalid location"))?;
1093    let room = session
1094        .db()
1095        .await
1096        .update_room_participant_location(room_id, session.connection_id, location)
1097        .await?;
1098    room_updated(&room, &session.peer);
1099    response.send(proto::Ack {})?;
1100    Ok(())
1101}
1102
1103async fn share_project(
1104    request: proto::ShareProject,
1105    response: Response<proto::ShareProject>,
1106    session: Session,
1107) -> Result<()> {
1108    let (project_id, room) = &*session
1109        .db()
1110        .await
1111        .share_project(
1112            RoomId::from_proto(request.room_id),
1113            session.connection_id,
1114            &request.worktrees,
1115        )
1116        .await?;
1117    response.send(proto::ShareProjectResponse {
1118        project_id: project_id.to_proto(),
1119    })?;
1120    room_updated(&room, &session.peer);
1121
1122    Ok(())
1123}
1124
1125async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1126    let project_id = ProjectId::from_proto(message.project_id);
1127
1128    let (room, guest_connection_ids) = &*session
1129        .db()
1130        .await
1131        .unshare_project(project_id, session.connection_id)
1132        .await?;
1133
1134    broadcast(
1135        session.connection_id,
1136        guest_connection_ids.iter().copied(),
1137        |conn_id| session.peer.send(conn_id, message.clone()),
1138    );
1139    room_updated(&room, &session.peer);
1140
1141    Ok(())
1142}
1143
1144async fn join_project(
1145    request: proto::JoinProject,
1146    response: Response<proto::JoinProject>,
1147    session: Session,
1148) -> Result<()> {
1149    let project_id = ProjectId::from_proto(request.project_id);
1150    let guest_user_id = session.user_id;
1151
1152    tracing::info!(%project_id, "join project");
1153
1154    let (project, replica_id) = &mut *session
1155        .db()
1156        .await
1157        .join_project(project_id, session.connection_id)
1158        .await?;
1159
1160    let collaborators = project
1161        .collaborators
1162        .iter()
1163        .map(|collaborator| {
1164            let peer_id = proto::PeerId {
1165                owner_id: collaborator.connection_server_id.0 as u32,
1166                id: collaborator.connection_id as u32,
1167            };
1168            proto::Collaborator {
1169                peer_id: Some(peer_id),
1170                replica_id: collaborator.replica_id.0 as u32,
1171                user_id: collaborator.user_id.to_proto(),
1172            }
1173        })
1174        .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1175        .collect::<Vec<_>>();
1176    let worktrees = project
1177        .worktrees
1178        .iter()
1179        .map(|(id, worktree)| proto::WorktreeMetadata {
1180            id: *id,
1181            root_name: worktree.root_name.clone(),
1182            visible: worktree.visible,
1183            abs_path: worktree.abs_path.clone(),
1184        })
1185        .collect::<Vec<_>>();
1186
1187    for collaborator in &collaborators {
1188        session
1189            .peer
1190            .send(
1191                collaborator.peer_id.unwrap().into(),
1192                proto::AddProjectCollaborator {
1193                    project_id: project_id.to_proto(),
1194                    collaborator: Some(proto::Collaborator {
1195                        peer_id: Some(session.connection_id.into()),
1196                        replica_id: replica_id.0 as u32,
1197                        user_id: guest_user_id.to_proto(),
1198                    }),
1199                },
1200            )
1201            .trace_err();
1202    }
1203
1204    // First, we send the metadata associated with each worktree.
1205    response.send(proto::JoinProjectResponse {
1206        worktrees: worktrees.clone(),
1207        replica_id: replica_id.0 as u32,
1208        collaborators: collaborators.clone(),
1209        language_servers: project.language_servers.clone(),
1210    })?;
1211
1212    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1213        #[cfg(any(test, feature = "test-support"))]
1214        const MAX_CHUNK_SIZE: usize = 2;
1215        #[cfg(not(any(test, feature = "test-support")))]
1216        const MAX_CHUNK_SIZE: usize = 256;
1217
1218        // Stream this worktree's entries.
1219        let message = proto::UpdateWorktree {
1220            project_id: project_id.to_proto(),
1221            worktree_id,
1222            abs_path: worktree.abs_path.clone(),
1223            root_name: worktree.root_name,
1224            updated_entries: worktree.entries,
1225            removed_entries: Default::default(),
1226            scan_id: worktree.scan_id,
1227            is_last_update: worktree.is_complete,
1228        };
1229        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1230            session.peer.send(session.connection_id, update.clone())?;
1231        }
1232
1233        // Stream this worktree's diagnostics.
1234        for summary in worktree.diagnostic_summaries {
1235            session.peer.send(
1236                session.connection_id,
1237                proto::UpdateDiagnosticSummary {
1238                    project_id: project_id.to_proto(),
1239                    worktree_id: worktree.id,
1240                    summary: Some(summary),
1241                },
1242            )?;
1243        }
1244    }
1245
1246    for language_server in &project.language_servers {
1247        session.peer.send(
1248            session.connection_id,
1249            proto::UpdateLanguageServer {
1250                project_id: project_id.to_proto(),
1251                language_server_id: language_server.id,
1252                variant: Some(
1253                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1254                        proto::LspDiskBasedDiagnosticsUpdated {},
1255                    ),
1256                ),
1257            },
1258        )?;
1259    }
1260
1261    Ok(())
1262}
1263
1264async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1265    let sender_id = session.connection_id;
1266    let project_id = ProjectId::from_proto(request.project_id);
1267
1268    let project = session
1269        .db()
1270        .await
1271        .leave_project(project_id, sender_id)
1272        .await?;
1273    tracing::info!(
1274        %project_id,
1275        host_user_id = %project.host_user_id,
1276        host_connection_id = %project.host_connection_id,
1277        "leave project"
1278    );
1279    project_left(&project, &session);
1280
1281    Ok(())
1282}
1283
1284async fn update_project(
1285    request: proto::UpdateProject,
1286    response: Response<proto::UpdateProject>,
1287    session: Session,
1288) -> Result<()> {
1289    let project_id = ProjectId::from_proto(request.project_id);
1290    let (room, guest_connection_ids) = &*session
1291        .db()
1292        .await
1293        .update_project(project_id, session.connection_id, &request.worktrees)
1294        .await?;
1295    broadcast(
1296        session.connection_id,
1297        guest_connection_ids.iter().copied(),
1298        |connection_id| {
1299            session
1300                .peer
1301                .forward_send(session.connection_id, connection_id, request.clone())
1302        },
1303    );
1304    room_updated(&room, &session.peer);
1305    response.send(proto::Ack {})?;
1306
1307    Ok(())
1308}
1309
1310async fn update_worktree(
1311    request: proto::UpdateWorktree,
1312    response: Response<proto::UpdateWorktree>,
1313    session: Session,
1314) -> Result<()> {
1315    let guest_connection_ids = session
1316        .db()
1317        .await
1318        .update_worktree(&request, session.connection_id)
1319        .await?;
1320
1321    broadcast(
1322        session.connection_id,
1323        guest_connection_ids.iter().copied(),
1324        |connection_id| {
1325            session
1326                .peer
1327                .forward_send(session.connection_id, connection_id, request.clone())
1328        },
1329    );
1330    response.send(proto::Ack {})?;
1331    Ok(())
1332}
1333
1334async fn update_diagnostic_summary(
1335    message: proto::UpdateDiagnosticSummary,
1336    session: Session,
1337) -> Result<()> {
1338    let guest_connection_ids = session
1339        .db()
1340        .await
1341        .update_diagnostic_summary(&message, session.connection_id)
1342        .await?;
1343
1344    broadcast(
1345        session.connection_id,
1346        guest_connection_ids.iter().copied(),
1347        |connection_id| {
1348            session
1349                .peer
1350                .forward_send(session.connection_id, connection_id, message.clone())
1351        },
1352    );
1353
1354    Ok(())
1355}
1356
1357async fn start_language_server(
1358    request: proto::StartLanguageServer,
1359    session: Session,
1360) -> Result<()> {
1361    let guest_connection_ids = session
1362        .db()
1363        .await
1364        .start_language_server(&request, session.connection_id)
1365        .await?;
1366
1367    broadcast(
1368        session.connection_id,
1369        guest_connection_ids.iter().copied(),
1370        |connection_id| {
1371            session
1372                .peer
1373                .forward_send(session.connection_id, connection_id, request.clone())
1374        },
1375    );
1376    Ok(())
1377}
1378
1379async fn update_language_server(
1380    request: proto::UpdateLanguageServer,
1381    session: Session,
1382) -> Result<()> {
1383    let project_id = ProjectId::from_proto(request.project_id);
1384    let project_connection_ids = session
1385        .db()
1386        .await
1387        .project_connection_ids(project_id, session.connection_id)
1388        .await?;
1389    broadcast(
1390        session.connection_id,
1391        project_connection_ids.iter().copied(),
1392        |connection_id| {
1393            session
1394                .peer
1395                .forward_send(session.connection_id, connection_id, request.clone())
1396        },
1397    );
1398    Ok(())
1399}
1400
1401async fn forward_project_request<T>(
1402    request: T,
1403    response: Response<T>,
1404    session: Session,
1405) -> Result<()>
1406where
1407    T: EntityMessage + RequestMessage,
1408{
1409    let project_id = ProjectId::from_proto(request.remote_entity_id());
1410    let host_connection_id = {
1411        let collaborators = session
1412            .db()
1413            .await
1414            .project_collaborators(project_id, session.connection_id)
1415            .await?;
1416        let host = collaborators
1417            .iter()
1418            .find(|collaborator| collaborator.is_host)
1419            .ok_or_else(|| anyhow!("host not found"))?;
1420        ConnectionId {
1421            owner_id: host.connection_server_id.0 as u32,
1422            id: host.connection_id as u32,
1423        }
1424    };
1425
1426    let payload = session
1427        .peer
1428        .forward_request(session.connection_id, host_connection_id, request)
1429        .await?;
1430
1431    response.send(payload)?;
1432    Ok(())
1433}
1434
/// Forwards a `SaveBuffer` request to the project's host and fans out the reply.
///
/// The host's reply is forwarded to every other collaborator and then sent back
/// to the requester:
/// - the requester is removed by the `retain` below;
/// - the host is skipped by `broadcast`'s sender filter.
///
/// Fails with "host not found" when no collaborator is marked as host.
async fn save_buffer(
    request: proto::SaveBuffer,
    response: Response<proto::SaveBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Scoped so the collaborator list is dropped before awaiting the host.
    let host_connection_id = {
        let collaborators = session
            .db()
            .await
            .project_collaborators(project_id, session.connection_id)
            .await?;
        let host = collaborators
            .iter()
            .find(|collaborator| collaborator.is_host)
            .ok_or_else(|| anyhow!("host not found"))?;
        ConnectionId {
            owner_id: host.connection_server_id.0 as u32,
            id: host.connection_id as u32,
        }
    };
    let response_payload = session
        .peer
        .forward_request(session.connection_id, host_connection_id, request.clone())
        .await?;

    // Collaborators are fetched again after the await, presumably because
    // membership may have changed while waiting on the host — confirm.
    let mut collaborators = session
        .db()
        .await
        .project_collaborators(project_id, session.connection_id)
        .await?;
    // The requester gets the reply via `response` below, not via the broadcast.
    collaborators.retain(|collaborator| {
        let collaborator_connection = ConnectionId {
            owner_id: collaborator.connection_server_id.0 as u32,
            id: collaborator.connection_id as u32,
        };
        collaborator_connection != session.connection_id
    });
    let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
        owner_id: collaborator.connection_server_id.0 as u32,
        id: collaborator.connection_id as u32,
    });
    // The host is the sender here, so `broadcast` skips it.
    broadcast(host_connection_id, project_connection_ids, |conn_id| {
        session
            .peer
            .forward_send(host_connection_id, conn_id, response_payload.clone())
    });
    response.send(response_payload)?;
    Ok(())
}
1485
1486async fn create_buffer_for_peer(
1487    request: proto::CreateBufferForPeer,
1488    session: Session,
1489) -> Result<()> {
1490    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1491    session
1492        .peer
1493        .forward_send(session.connection_id, peer_id.into(), request)?;
1494    Ok(())
1495}
1496
1497async fn update_buffer(
1498    request: proto::UpdateBuffer,
1499    response: Response<proto::UpdateBuffer>,
1500    session: Session,
1501) -> Result<()> {
1502    let project_id = ProjectId::from_proto(request.project_id);
1503    let project_connection_ids = session
1504        .db()
1505        .await
1506        .project_connection_ids(project_id, session.connection_id)
1507        .await?;
1508
1509    dbg!(session.connection_id, &*project_connection_ids);
1510    broadcast(
1511        session.connection_id,
1512        project_connection_ids.iter().copied(),
1513        |connection_id| {
1514            session
1515                .peer
1516                .forward_send(session.connection_id, connection_id, request.clone())
1517        },
1518    );
1519    response.send(proto::Ack {})?;
1520    Ok(())
1521}
1522
1523async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1524    let project_id = ProjectId::from_proto(request.project_id);
1525    let project_connection_ids = session
1526        .db()
1527        .await
1528        .project_connection_ids(project_id, session.connection_id)
1529        .await?;
1530
1531    broadcast(
1532        session.connection_id,
1533        project_connection_ids.iter().copied(),
1534        |connection_id| {
1535            session
1536                .peer
1537                .forward_send(session.connection_id, connection_id, request.clone())
1538        },
1539    );
1540    Ok(())
1541}
1542
1543async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1544    let project_id = ProjectId::from_proto(request.project_id);
1545    let project_connection_ids = session
1546        .db()
1547        .await
1548        .project_connection_ids(project_id, session.connection_id)
1549        .await?;
1550    broadcast(
1551        session.connection_id,
1552        project_connection_ids.iter().copied(),
1553        |connection_id| {
1554            session
1555                .peer
1556                .forward_send(session.connection_id, connection_id, request.clone())
1557        },
1558    );
1559    Ok(())
1560}
1561
1562async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1563    let project_id = ProjectId::from_proto(request.project_id);
1564    let project_connection_ids = session
1565        .db()
1566        .await
1567        .project_connection_ids(project_id, session.connection_id)
1568        .await?;
1569    broadcast(
1570        session.connection_id,
1571        project_connection_ids.iter().copied(),
1572        |connection_id| {
1573            session
1574                .peer
1575                .forward_send(session.connection_id, connection_id, request.clone())
1576        },
1577    );
1578    Ok(())
1579}
1580
1581async fn follow(
1582    request: proto::Follow,
1583    response: Response<proto::Follow>,
1584    session: Session,
1585) -> Result<()> {
1586    let project_id = ProjectId::from_proto(request.project_id);
1587    let leader_id = request
1588        .leader_id
1589        .ok_or_else(|| anyhow!("invalid leader id"))?
1590        .into();
1591    let follower_id = session.connection_id;
1592    {
1593        let project_connection_ids = session
1594            .db()
1595            .await
1596            .project_connection_ids(project_id, session.connection_id)
1597            .await?;
1598
1599        if !project_connection_ids.contains(&leader_id) {
1600            Err(anyhow!("no such peer"))?;
1601        }
1602    }
1603
1604    let mut response_payload = session
1605        .peer
1606        .forward_request(session.connection_id, leader_id, request)
1607        .await?;
1608    response_payload
1609        .views
1610        .retain(|view| view.leader_id != Some(follower_id.into()));
1611    response.send(response_payload)?;
1612    Ok(())
1613}
1614
1615async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1616    let project_id = ProjectId::from_proto(request.project_id);
1617    let leader_id = request
1618        .leader_id
1619        .ok_or_else(|| anyhow!("invalid leader id"))?
1620        .into();
1621    let project_connection_ids = session
1622        .db()
1623        .await
1624        .project_connection_ids(project_id, session.connection_id)
1625        .await?;
1626    if !project_connection_ids.contains(&leader_id) {
1627        Err(anyhow!("no such peer"))?;
1628    }
1629    session
1630        .peer
1631        .forward_send(session.connection_id, leader_id, request)?;
1632    Ok(())
1633}
1634
1635async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1636    let project_id = ProjectId::from_proto(request.project_id);
1637    let project_connection_ids = session
1638        .db
1639        .lock()
1640        .await
1641        .project_connection_ids(project_id, session.connection_id)
1642        .await?;
1643
1644    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1645        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1646        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1647        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1648    });
1649    for follower_peer_id in request.follower_ids.iter().copied() {
1650        let follower_connection_id = follower_peer_id.into();
1651        if project_connection_ids.contains(&follower_connection_id)
1652            && Some(follower_peer_id) != leader_id
1653        {
1654            session.peer.forward_send(
1655                session.connection_id,
1656                follower_connection_id,
1657                request.clone(),
1658            )?;
1659        }
1660    }
1661    Ok(())
1662}
1663
1664async fn get_users(
1665    request: proto::GetUsers,
1666    response: Response<proto::GetUsers>,
1667    session: Session,
1668) -> Result<()> {
1669    let user_ids = request
1670        .user_ids
1671        .into_iter()
1672        .map(UserId::from_proto)
1673        .collect();
1674    let users = session
1675        .db()
1676        .await
1677        .get_users_by_ids(user_ids)
1678        .await?
1679        .into_iter()
1680        .map(|user| proto::User {
1681            id: user.id.to_proto(),
1682            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1683            github_login: user.github_login,
1684        })
1685        .collect();
1686    response.send(proto::UsersResponse { users })?;
1687    Ok(())
1688}
1689
1690async fn fuzzy_search_users(
1691    request: proto::FuzzySearchUsers,
1692    response: Response<proto::FuzzySearchUsers>,
1693    session: Session,
1694) -> Result<()> {
1695    let query = request.query;
1696    let users = match query.len() {
1697        0 => vec![],
1698        1 | 2 => session
1699            .db()
1700            .await
1701            .get_user_by_github_account(&query, None)
1702            .await?
1703            .into_iter()
1704            .collect(),
1705        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1706    };
1707    let users = users
1708        .into_iter()
1709        .filter(|user| user.id != session.user_id)
1710        .map(|user| proto::User {
1711            id: user.id.to_proto(),
1712            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1713            github_login: user.github_login,
1714        })
1715        .collect();
1716    response.send(proto::UsersResponse { users })?;
1717    Ok(())
1718}
1719
1720async fn request_contact(
1721    request: proto::RequestContact,
1722    response: Response<proto::RequestContact>,
1723    session: Session,
1724) -> Result<()> {
1725    let requester_id = session.user_id;
1726    let responder_id = UserId::from_proto(request.responder_id);
1727    if requester_id == responder_id {
1728        return Err(anyhow!("cannot add yourself as a contact"))?;
1729    }
1730
1731    session
1732        .db()
1733        .await
1734        .send_contact_request(requester_id, responder_id)
1735        .await?;
1736
1737    // Update outgoing contact requests of requester
1738    let mut update = proto::UpdateContacts::default();
1739    update.outgoing_requests.push(responder_id.to_proto());
1740    for connection_id in session
1741        .connection_pool()
1742        .await
1743        .user_connection_ids(requester_id)
1744    {
1745        session.peer.send(connection_id, update.clone())?;
1746    }
1747
1748    // Update incoming contact requests of responder
1749    let mut update = proto::UpdateContacts::default();
1750    update
1751        .incoming_requests
1752        .push(proto::IncomingContactRequest {
1753            requester_id: requester_id.to_proto(),
1754            should_notify: true,
1755        });
1756    for connection_id in session
1757        .connection_pool()
1758        .await
1759        .user_connection_ids(responder_id)
1760    {
1761        session.peer.send(connection_id, update.clone())?;
1762    }
1763
1764    response.send(proto::Ack {})?;
1765    Ok(())
1766}
1767
1768async fn respond_to_contact_request(
1769    request: proto::RespondToContactRequest,
1770    response: Response<proto::RespondToContactRequest>,
1771    session: Session,
1772) -> Result<()> {
1773    let responder_id = session.user_id;
1774    let requester_id = UserId::from_proto(request.requester_id);
1775    let db = session.db().await;
1776    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1777        db.dismiss_contact_notification(responder_id, requester_id)
1778            .await?;
1779    } else {
1780        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1781
1782        db.respond_to_contact_request(responder_id, requester_id, accept)
1783            .await?;
1784        let requester_busy = db.is_user_busy(requester_id).await?;
1785        let responder_busy = db.is_user_busy(responder_id).await?;
1786
1787        let pool = session.connection_pool().await;
1788        // Update responder with new contact
1789        let mut update = proto::UpdateContacts::default();
1790        if accept {
1791            update
1792                .contacts
1793                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1794        }
1795        update
1796            .remove_incoming_requests
1797            .push(requester_id.to_proto());
1798        for connection_id in pool.user_connection_ids(responder_id) {
1799            session.peer.send(connection_id, update.clone())?;
1800        }
1801
1802        // Update requester with new contact
1803        let mut update = proto::UpdateContacts::default();
1804        if accept {
1805            update
1806                .contacts
1807                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1808        }
1809        update
1810            .remove_outgoing_requests
1811            .push(responder_id.to_proto());
1812        for connection_id in pool.user_connection_ids(requester_id) {
1813            session.peer.send(connection_id, update.clone())?;
1814        }
1815    }
1816
1817    response.send(proto::Ack {})?;
1818    Ok(())
1819}
1820
1821async fn remove_contact(
1822    request: proto::RemoveContact,
1823    response: Response<proto::RemoveContact>,
1824    session: Session,
1825) -> Result<()> {
1826    let requester_id = session.user_id;
1827    let responder_id = UserId::from_proto(request.user_id);
1828    let db = session.db().await;
1829    db.remove_contact(requester_id, responder_id).await?;
1830
1831    let pool = session.connection_pool().await;
1832    // Update outgoing contact requests of requester
1833    let mut update = proto::UpdateContacts::default();
1834    update
1835        .remove_outgoing_requests
1836        .push(responder_id.to_proto());
1837    for connection_id in pool.user_connection_ids(requester_id) {
1838        session.peer.send(connection_id, update.clone())?;
1839    }
1840
1841    // Update incoming contact requests of responder
1842    let mut update = proto::UpdateContacts::default();
1843    update
1844        .remove_incoming_requests
1845        .push(requester_id.to_proto());
1846    for connection_id in pool.user_connection_ids(responder_id) {
1847        session.peer.send(connection_id, update.clone())?;
1848    }
1849
1850    response.send(proto::Ack {})?;
1851    Ok(())
1852}
1853
1854async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1855    let project_id = ProjectId::from_proto(request.project_id);
1856    let project_connection_ids = session
1857        .db()
1858        .await
1859        .project_connection_ids(project_id, session.connection_id)
1860        .await?;
1861    broadcast(
1862        session.connection_id,
1863        project_connection_ids.iter().copied(),
1864        |connection_id| {
1865            session
1866                .peer
1867                .forward_send(session.connection_id, connection_id, request.clone())
1868        },
1869    );
1870    Ok(())
1871}
1872
1873async fn get_private_user_info(
1874    _request: proto::GetPrivateUserInfo,
1875    response: Response<proto::GetPrivateUserInfo>,
1876    session: Session,
1877) -> Result<()> {
1878    let metrics_id = session
1879        .db()
1880        .await
1881        .get_user_metrics_id(session.user_id)
1882        .await?;
1883    let user = session
1884        .db()
1885        .await
1886        .get_user_by_id(session.user_id)
1887        .await?
1888        .ok_or_else(|| anyhow!("user not found"))?;
1889    response.send(proto::GetPrivateUserInfoResponse {
1890        metrics_id,
1891        staff: user.admin,
1892    })?;
1893    Ok(())
1894}
1895
1896fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1897    match message {
1898        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1899        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1900        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1901        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1902        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1903            code: frame.code.into(),
1904            reason: frame.reason,
1905        })),
1906    }
1907}
1908
1909fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1910    match message {
1911        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1912        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1913        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1914        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1915        AxumMessage::Close(frame) => {
1916            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1917                code: frame.code.into(),
1918                reason: frame.reason,
1919            }))
1920        }
1921    }
1922}
1923
1924fn build_initial_contacts_update(
1925    contacts: Vec<db::Contact>,
1926    pool: &ConnectionPool,
1927) -> proto::UpdateContacts {
1928    let mut update = proto::UpdateContacts::default();
1929
1930    for contact in contacts {
1931        match contact {
1932            db::Contact::Accepted {
1933                user_id,
1934                should_notify,
1935                busy,
1936            } => {
1937                update
1938                    .contacts
1939                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1940            }
1941            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1942            db::Contact::Incoming {
1943                user_id,
1944                should_notify,
1945            } => update
1946                .incoming_requests
1947                .push(proto::IncomingContactRequest {
1948                    requester_id: user_id.to_proto(),
1949                    should_notify,
1950                }),
1951        }
1952    }
1953
1954    update
1955}
1956
1957fn contact_for_user(
1958    user_id: UserId,
1959    should_notify: bool,
1960    busy: bool,
1961    pool: &ConnectionPool,
1962) -> proto::Contact {
1963    proto::Contact {
1964        user_id: user_id.to_proto(),
1965        online: pool.is_user_online(user_id),
1966        busy,
1967        should_notify,
1968    }
1969}
1970
1971fn room_updated(room: &proto::Room, peer: &Peer) {
1972    for participant in &room.participants {
1973        if let Some(peer_id) = participant
1974            .peer_id
1975            .ok_or_else(|| anyhow!("invalid participant peer id"))
1976            .trace_err()
1977        {
1978            peer.send(
1979                peer_id.into(),
1980                proto::RoomUpdated {
1981                    room: Some(room.clone()),
1982                },
1983            )
1984            .trace_err();
1985        }
1986    }
1987}
1988
1989async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1990    let db = session.db().await;
1991    let contacts = db.get_contacts(user_id).await?;
1992    let busy = db.is_user_busy(user_id).await?;
1993
1994    let pool = session.connection_pool().await;
1995    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1996    for contact in contacts {
1997        if let db::Contact::Accepted {
1998            user_id: contact_user_id,
1999            ..
2000        } = contact
2001        {
2002            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2003                session
2004                    .peer
2005                    .send(
2006                        contact_conn_id,
2007                        proto::UpdateContacts {
2008                            contacts: vec![updated_contact.clone()],
2009                            remove_contacts: Default::default(),
2010                            incoming_requests: Default::default(),
2011                            remove_incoming_requests: Default::default(),
2012                            outgoing_requests: Default::default(),
2013                            remove_outgoing_requests: Default::default(),
2014                        },
2015                    )
2016                    .trace_err();
2017            }
2018        }
2019    }
2020    Ok(())
2021}
2022
2023async fn leave_room_for_session(session: &Session) -> Result<()> {
2024    let mut contacts_to_update = HashSet::default();
2025
2026    let room_id;
2027    let canceled_calls_to_user_ids;
2028    let live_kit_room;
2029    let delete_live_kit_room;
2030    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2031        contacts_to_update.insert(session.user_id);
2032
2033        for project in left_room.left_projects.values() {
2034            project_left(project, session);
2035        }
2036
2037        room_updated(&left_room.room, &session.peer);
2038        room_id = RoomId::from_proto(left_room.room.id);
2039        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2040        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2041        delete_live_kit_room = left_room.room.participants.is_empty();
2042    } else {
2043        return Ok(());
2044    }
2045
2046    {
2047        let pool = session.connection_pool().await;
2048        for canceled_user_id in canceled_calls_to_user_ids {
2049            for connection_id in pool.user_connection_ids(canceled_user_id) {
2050                session
2051                    .peer
2052                    .send(
2053                        connection_id,
2054                        proto::CallCanceled {
2055                            room_id: room_id.to_proto(),
2056                        },
2057                    )
2058                    .trace_err();
2059            }
2060            contacts_to_update.insert(canceled_user_id);
2061        }
2062    }
2063
2064    for contact_user_id in contacts_to_update {
2065        update_user_contacts(contact_user_id, &session).await?;
2066    }
2067
2068    if let Some(live_kit) = session.live_kit_client.as_ref() {
2069        live_kit
2070            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
2071            .await
2072            .trace_err();
2073
2074        if delete_live_kit_room {
2075            live_kit.delete_room(live_kit_room).await.trace_err();
2076        }
2077    }
2078
2079    Ok(())
2080}
2081
2082fn project_left(project: &db::LeftProject, session: &Session) {
2083    for connection_id in &project.connection_ids {
2084        if project.host_user_id == session.user_id {
2085            session
2086                .peer
2087                .send(
2088                    *connection_id,
2089                    proto::UnshareProject {
2090                        project_id: project.id.to_proto(),
2091                    },
2092                )
2093                .trace_err();
2094        } else {
2095            session
2096                .peer
2097                .send(
2098                    *connection_id,
2099                    proto::RemoveProjectCollaborator {
2100                        project_id: project.id.to_proto(),
2101                        peer_id: Some(session.connection_id.into()),
2102                    },
2103                )
2104                .trace_err();
2105        }
2106    }
2107
2108    session
2109        .peer
2110        .send(
2111            session.connection_id,
2112            proto::UnshareProject {
2113                project_id: project.id.to_proto(),
2114            },
2115        )
2116        .trace_err();
2117}
2118
2119pub trait ResultExt {
2120    type Ok;
2121
2122    fn trace_err(self) -> Option<Self::Ok>;
2123}
2124
2125impl<T, E> ResultExt for Result<T, E>
2126where
2127    E: std::fmt::Debug,
2128{
2129    type Ok = T;
2130
2131    fn trace_err(self) -> Option<T> {
2132        match self {
2133            Ok(value) => Some(value),
2134            Err(error) => {
2135                tracing::error!("{:?}", error);
2136                None
2137            }
2138        }
2139    }
2140}