rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::{Duration, Instant},
  55};
  56use tokio::sync::{watch, Semaphore};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  62
  63lazy_static! {
  64    static ref METRIC_CONNECTIONS: IntGauge =
  65        register_int_gauge!("connections", "number of connections").unwrap();
  66    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  67        "shared_projects",
  68        "number of open projects with one or more guests"
  69    )
  70    .unwrap();
  71}
  72
  73type MessageHandler =
  74    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  75
  76struct Response<R> {
  77    peer: Arc<Peer>,
  78    receipt: Receipt<R>,
  79    responded: Arc<AtomicBool>,
  80}
  81
  82impl<R: RequestMessage> Response<R> {
  83    fn send(self, payload: R::Response) -> Result<()> {
  84        self.responded.store(true, SeqCst);
  85        self.peer.respond(self.receipt, payload)?;
  86        Ok(())
  87    }
  88}
  89
  90#[derive(Clone)]
  91struct Session {
  92    user_id: UserId,
  93    connection_id: ConnectionId,
  94    db: Arc<tokio::sync::Mutex<DbHandle>>,
  95    peer: Arc<Peer>,
  96    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  97    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  98    executor: Executor,
  99}
 100
 101impl Session {
 102    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 103        #[cfg(test)]
 104        tokio::task::yield_now().await;
 105        let guard = self.db.lock().await;
 106        #[cfg(test)]
 107        tokio::task::yield_now().await;
 108        guard
 109    }
 110
 111    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 112        #[cfg(test)]
 113        tokio::task::yield_now().await;
 114        let guard = self.connection_pool.lock();
 115        ConnectionPoolGuard {
 116            guard,
 117            _not_send: PhantomData,
 118        }
 119    }
 120}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
 131struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
 141pub struct Server {
 142    id: parking_lot::Mutex<ServerId>,
 143    peer: Arc<Peer>,
 144    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 145    app_state: Arc<AppState>,
 146    executor: Executor,
 147    handlers: HashMap<TypeId, MessageHandler>,
 148    teardown: watch::Sender<()>,
 149}
 150
 151pub(crate) struct ConnectionPoolGuard<'a> {
 152    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 153    _not_send: PhantomData<Rc<()>>,
 154}
 155
 156#[derive(Serialize)]
 157pub struct ServerSnapshot<'a> {
 158    peer: &'a Peer,
 159    #[serde(serialize_with = "serialize_deref")]
 160    connection_pool: ConnectionPoolGuard<'a>,
 161}
 162
 163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 164where
 165    S: Serializer,
 166    T: Deref<Target = U>,
 167    U: Serialize,
 168{
 169    Serialize::serialize(value.deref(), serializer)
 170}
 171
 172impl Server {
 173    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 174        let mut server = Self {
 175            id: parking_lot::Mutex::new(id),
 176            peer: Peer::new(id.0 as u32),
 177            app_state,
 178            executor,
 179            connection_pool: Default::default(),
 180            handlers: Default::default(),
 181            teardown: watch::channel(()).0,
 182        };
 183
 184        server
 185            .add_request_handler(ping)
 186            .add_request_handler(create_room)
 187            .add_request_handler(join_room)
 188            .add_request_handler(rejoin_room)
 189            .add_request_handler(leave_room)
 190            .add_request_handler(call)
 191            .add_request_handler(cancel_call)
 192            .add_message_handler(decline_call)
 193            .add_request_handler(update_participant_location)
 194            .add_request_handler(share_project)
 195            .add_message_handler(unshare_project)
 196            .add_request_handler(join_project)
 197            .add_message_handler(leave_project)
 198            .add_request_handler(update_project)
 199            .add_request_handler(update_worktree)
 200            .add_message_handler(start_language_server)
 201            .add_message_handler(update_language_server)
 202            .add_message_handler(update_diagnostic_summary)
 203            .add_message_handler(update_worktree_settings)
 204            .add_message_handler(refresh_inlay_hints)
 205            .add_request_handler(forward_project_request::<proto::GetHover>)
 206            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 207            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 208            .add_request_handler(forward_project_request::<proto::GetReferences>)
 209            .add_request_handler(forward_project_request::<proto::SearchProject>)
 210            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 211            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 212            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 213            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 214            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 215            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 216            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 217            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 218            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 219            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 220            .add_request_handler(forward_project_request::<proto::PerformRename>)
 221            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 222            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 223            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 224            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 225            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 226            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 227            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 228            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 229            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 230            .add_request_handler(forward_project_request::<proto::InlayHints>)
 231            .add_message_handler(create_buffer_for_peer)
 232            .add_request_handler(update_buffer)
 233            .add_message_handler(update_buffer_file)
 234            .add_message_handler(buffer_reloaded)
 235            .add_message_handler(buffer_saved)
 236            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 237            .add_request_handler(get_users)
 238            .add_request_handler(fuzzy_search_users)
 239            .add_request_handler(request_contact)
 240            .add_request_handler(remove_contact)
 241            .add_request_handler(respond_to_contact_request)
 242            .add_request_handler(create_channel)
 243            .add_request_handler(invite_channel_member)
 244            .add_request_handler(remove_channel_member)
 245            .add_request_handler(respond_to_channel_invite)
 246            .add_request_handler(follow)
 247            .add_message_handler(unfollow)
 248            .add_message_handler(update_followers)
 249            .add_message_handler(update_diff_base)
 250            .add_request_handler(get_private_user_info);
 251
 252        Arc::new(server)
 253    }
 254
 255    pub async fn start(&self) -> Result<()> {
 256        let server_id = *self.id.lock();
 257        let app_state = self.app_state.clone();
 258        let peer = self.peer.clone();
 259        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 260        let pool = self.connection_pool.clone();
 261        let live_kit_client = self.app_state.live_kit_client.clone();
 262
 263        let span = info_span!("start server");
 264        self.executor.spawn_detached(
 265            async move {
 266                tracing::info!("waiting for cleanup timeout");
 267                timeout.await;
 268                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 269                if let Some(room_ids) = app_state
 270                    .db
 271                    .stale_room_ids(&app_state.config.zed_environment, server_id)
 272                    .await
 273                    .trace_err()
 274                {
 275                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 276                    for room_id in room_ids {
 277                        let mut contacts_to_update = HashSet::default();
 278                        let mut canceled_calls_to_user_ids = Vec::new();
 279                        let mut live_kit_room = String::new();
 280                        let mut delete_live_kit_room = false;
 281
 282                        if let Some(mut refreshed_room) = app_state
 283                            .db
 284                            .refresh_room(room_id, server_id)
 285                            .await
 286                            .trace_err()
 287                        {
 288                            tracing::info!(
 289                                room_id = room_id.0,
 290                                new_participant_count = refreshed_room.room.participants.len(),
 291                                "refreshed room"
 292                            );
 293                            room_updated(&refreshed_room.room, &peer);
 294                            contacts_to_update
 295                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 296                            contacts_to_update
 297                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 298                            canceled_calls_to_user_ids =
 299                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 300                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 301                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 302                        }
 303
 304                        {
 305                            let pool = pool.lock();
 306                            for canceled_user_id in canceled_calls_to_user_ids {
 307                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 308                                    peer.send(
 309                                        connection_id,
 310                                        proto::CallCanceled {
 311                                            room_id: room_id.to_proto(),
 312                                        },
 313                                    )
 314                                    .trace_err();
 315                                }
 316                            }
 317                        }
 318
 319                        for user_id in contacts_to_update {
 320                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 321                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 322                            if let Some((busy, contacts)) = busy.zip(contacts) {
 323                                let pool = pool.lock();
 324                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 325                                for contact in contacts {
 326                                    if let db::Contact::Accepted {
 327                                        user_id: contact_user_id,
 328                                        ..
 329                                    } = contact
 330                                    {
 331                                        for contact_conn_id in
 332                                            pool.user_connection_ids(contact_user_id)
 333                                        {
 334                                            peer.send(
 335                                                contact_conn_id,
 336                                                proto::UpdateContacts {
 337                                                    contacts: vec![updated_contact.clone()],
 338                                                    remove_contacts: Default::default(),
 339                                                    incoming_requests: Default::default(),
 340                                                    remove_incoming_requests: Default::default(),
 341                                                    outgoing_requests: Default::default(),
 342                                                    remove_outgoing_requests: Default::default(),
 343                                                },
 344                                            )
 345                                            .trace_err();
 346                                        }
 347                                    }
 348                                }
 349                            }
 350                        }
 351
 352                        if let Some(live_kit) = live_kit_client.as_ref() {
 353                            if delete_live_kit_room {
 354                                live_kit.delete_room(live_kit_room).await.trace_err();
 355                            }
 356                        }
 357                    }
 358                }
 359
 360                app_state
 361                    .db
 362                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 363                    .await
 364                    .trace_err();
 365            }
 366            .instrument(span),
 367        );
 368        Ok(())
 369    }
 370
 371    pub fn teardown(&self) {
 372        self.peer.teardown();
 373        self.connection_pool.lock().reset();
 374        let _ = self.teardown.send(());
 375    }
 376
 377    #[cfg(test)]
 378    pub fn reset(&self, id: ServerId) {
 379        self.teardown();
 380        *self.id.lock() = id;
 381        self.peer.reset(id.0 as u32);
 382    }
 383
 384    #[cfg(test)]
 385    pub fn id(&self) -> ServerId {
 386        *self.id.lock()
 387    }
 388
 389    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 390    where
 391        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 392        Fut: 'static + Send + Future<Output = Result<()>>,
 393        M: EnvelopedMessage,
 394    {
 395        let prev_handler = self.handlers.insert(
 396            TypeId::of::<M>(),
 397            Box::new(move |envelope, session| {
 398                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 399                let span = info_span!(
 400                    "handle message",
 401                    payload_type = envelope.payload_type_name()
 402                );
 403                span.in_scope(|| {
 404                    tracing::info!(
 405                        payload_type = envelope.payload_type_name(),
 406                        "message received"
 407                    );
 408                });
 409                let start_time = Instant::now();
 410                let future = (handler)(*envelope, session);
 411                async move {
 412                    let result = future.await;
 413                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 414                    match result {
 415                        Err(error) => {
 416                            tracing::error!(%error, ?duration_ms, "error handling message")
 417                        }
 418                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 419                    }
 420                }
 421                .instrument(span)
 422                .boxed()
 423            }),
 424        );
 425        if prev_handler.is_some() {
 426            panic!("registered a handler for the same message twice");
 427        }
 428        self
 429    }
 430
 431    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 432    where
 433        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 434        Fut: 'static + Send + Future<Output = Result<()>>,
 435        M: EnvelopedMessage,
 436    {
 437        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 438        self
 439    }
 440
 441    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 442    where
 443        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 444        Fut: Send + Future<Output = Result<()>>,
 445        M: RequestMessage,
 446    {
 447        let handler = Arc::new(handler);
 448        self.add_handler(move |envelope, session| {
 449            let receipt = envelope.receipt();
 450            let handler = handler.clone();
 451            async move {
 452                let peer = session.peer.clone();
 453                let responded = Arc::new(AtomicBool::default());
 454                let response = Response {
 455                    peer: peer.clone(),
 456                    responded: responded.clone(),
 457                    receipt,
 458                };
 459                match (handler)(envelope.payload, response, session).await {
 460                    Ok(()) => {
 461                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 462                            Ok(())
 463                        } else {
 464                            Err(anyhow!("handler did not send a response"))?
 465                        }
 466                    }
 467                    Err(error) => {
 468                        peer.respond_with_error(
 469                            receipt,
 470                            proto::Error {
 471                                message: error.to_string(),
 472                            },
 473                        )?;
 474                        Err(error)
 475                    }
 476                }
 477            }
 478        })
 479    }
 480
 481    pub fn handle_connection(
 482        self: &Arc<Self>,
 483        connection: Connection,
 484        address: String,
 485        user: User,
 486        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 487        executor: Executor,
 488    ) -> impl Future<Output = Result<()>> {
 489        let this = self.clone();
 490        let user_id = user.id;
 491        let login = user.github_login;
 492        let span = info_span!("handle connection", %user_id, %login, %address);
 493        let mut teardown = self.teardown.subscribe();
 494        async move {
 495            let (connection_id, handle_io, mut incoming_rx) = this
 496                .peer
 497                .add_connection(connection, {
 498                    let executor = executor.clone();
 499                    move |duration| executor.sleep(duration)
 500                });
 501
 502            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 503            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 504            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 505
 506            if let Some(send_connection_id) = send_connection_id.take() {
 507                let _ = send_connection_id.send(connection_id);
 508            }
 509
 510            if !user.connected_once {
 511                this.peer.send(connection_id, proto::ShowContacts {})?;
 512                this.app_state.db.set_user_connected_once(user_id, true).await?;
 513            }
 514
 515            let (contacts, invite_code) = future::try_join(
 516                this.app_state.db.get_contacts(user_id),
 517                this.app_state.db.get_invite_code_for_user(user_id)
 518            ).await?;
 519
 520            {
 521                let mut pool = this.connection_pool.lock();
 522                pool.add_connection(connection_id, user_id, user.admin);
 523                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 524
 525                if let Some((code, count)) = invite_code {
 526                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 527                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 528                        count: count as u32,
 529                    })?;
 530                }
 531            }
 532
 533            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 534                this.peer.send(connection_id, incoming_call)?;
 535            }
 536
 537            let session = Session {
 538                user_id,
 539                connection_id,
 540                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 541                peer: this.peer.clone(),
 542                connection_pool: this.connection_pool.clone(),
 543                live_kit_client: this.app_state.live_kit_client.clone(),
 544                executor: executor.clone(),
 545            };
 546            update_user_contacts(user_id, &session).await?;
 547
 548            let handle_io = handle_io.fuse();
 549            futures::pin_mut!(handle_io);
 550
 551            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 552            // This prevents deadlocks when e.g., client A performs a request to client B and
 553            // client B performs a request to client A. If both clients stop processing further
 554            // messages until their respective request completes, they won't have a chance to
 555            // respond to the other client's request and cause a deadlock.
 556            //
 557            // This arrangement ensures we will attempt to process earlier messages first, but fall
 558            // back to processing messages arrived later in the spirit of making progress.
 559            let mut foreground_message_handlers = FuturesUnordered::new();
 560            let concurrent_handlers = Arc::new(Semaphore::new(256));
 561            loop {
 562                let next_message = async {
 563                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 564                    let message = incoming_rx.next().await;
 565                    (permit, message)
 566                }.fuse();
 567                futures::pin_mut!(next_message);
 568                futures::select_biased! {
 569                    _ = teardown.changed().fuse() => return Ok(()),
 570                    result = handle_io => {
 571                        if let Err(error) = result {
 572                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 573                        }
 574                        break;
 575                    }
 576                    _ = foreground_message_handlers.next() => {}
 577                    next_message = next_message => {
 578                        let (permit, message) = next_message;
 579                        if let Some(message) = message {
 580                            let type_name = message.payload_type_name();
 581                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 582                            let span_enter = span.enter();
 583                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 584                                let is_background = message.is_background();
 585                                let handle_message = (handler)(message, session.clone());
 586                                drop(span_enter);
 587
 588                                let handle_message = async move {
 589                                    handle_message.await;
 590                                    drop(permit);
 591                                }.instrument(span);
 592                                if is_background {
 593                                    executor.spawn_detached(handle_message);
 594                                } else {
 595                                    foreground_message_handlers.push(handle_message);
 596                                }
 597                            } else {
 598                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 599                            }
 600                        } else {
 601                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 602                            break;
 603                        }
 604                    }
 605                }
 606            }
 607
 608            drop(foreground_message_handlers);
 609            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 610            if let Err(error) = connection_lost(session, teardown, executor).await {
 611                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 612            }
 613
 614            Ok(())
 615        }.instrument(span)
 616    }
 617
 618    pub async fn invite_code_redeemed(
 619        self: &Arc<Self>,
 620        inviter_id: UserId,
 621        invitee_id: UserId,
 622    ) -> Result<()> {
 623        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 624            if let Some(code) = &user.invite_code {
 625                let pool = self.connection_pool.lock();
 626                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 627                for connection_id in pool.user_connection_ids(inviter_id) {
 628                    self.peer.send(
 629                        connection_id,
 630                        proto::UpdateContacts {
 631                            contacts: vec![invitee_contact.clone()],
 632                            ..Default::default()
 633                        },
 634                    )?;
 635                    self.peer.send(
 636                        connection_id,
 637                        proto::UpdateInviteInfo {
 638                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 639                            count: user.invite_count as u32,
 640                        },
 641                    )?;
 642                }
 643            }
 644        }
 645        Ok(())
 646    }
 647
 648    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 649        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 650            if let Some(invite_code) = &user.invite_code {
 651                let pool = self.connection_pool.lock();
 652                for connection_id in pool.user_connection_ids(user_id) {
 653                    self.peer.send(
 654                        connection_id,
 655                        proto::UpdateInviteInfo {
 656                            url: format!(
 657                                "{}{}",
 658                                self.app_state.config.invite_link_prefix, invite_code
 659                            ),
 660                            count: user.invite_count as u32,
 661                        },
 662                    )?;
 663                }
 664            }
 665        }
 666        Ok(())
 667    }
 668
 669    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 670        ServerSnapshot {
 671            connection_pool: ConnectionPoolGuard {
 672                guard: self.connection_pool.lock(),
 673                _not_send: PhantomData,
 674            },
 675            peer: &self.peer,
 676        }
 677    }
 678}
 679
 680impl<'a> Deref for ConnectionPoolGuard<'a> {
 681    type Target = ConnectionPool;
 682
 683    fn deref(&self) -> &Self::Target {
 684        &*self.guard
 685    }
 686}
 687
 688impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 689    fn deref_mut(&mut self) -> &mut Self::Target {
 690        &mut *self.guard
 691    }
 692}
 693
 694impl<'a> Drop for ConnectionPoolGuard<'a> {
 695    fn drop(&mut self) {
 696        #[cfg(test)]
 697        self.check_invariants();
 698    }
 699}
 700
 701fn broadcast<F>(
 702    sender_id: Option<ConnectionId>,
 703    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 704    mut f: F,
 705) where
 706    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 707{
 708    for receiver_id in receiver_ids {
 709        if Some(receiver_id) != sender_id {
 710            if let Err(error) = f(receiver_id) {
 711                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 712            }
 713        }
 714    }
 715}
 716
 717lazy_static! {
 718    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 719}
 720
 721pub struct ProtocolVersion(u32);
 722
 723impl Header for ProtocolVersion {
 724    fn name() -> &'static HeaderName {
 725        &ZED_PROTOCOL_VERSION
 726    }
 727
 728    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 729    where
 730        Self: Sized,
 731        I: Iterator<Item = &'i axum::http::HeaderValue>,
 732    {
 733        let version = values
 734            .next()
 735            .ok_or_else(axum::headers::Error::invalid)?
 736            .to_str()
 737            .map_err(|_| axum::headers::Error::invalid())?
 738            .parse()
 739            .map_err(|_| axum::headers::Error::invalid())?;
 740        Ok(Self(version))
 741    }
 742
 743    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 744        values.extend([self.0.to_string().parse().unwrap()]);
 745    }
 746}
 747
 748pub fn routes(server: Arc<Server>) -> Router<Body> {
 749    Router::new()
 750        .route("/rpc", get(handle_websocket_request))
 751        .layer(
 752            ServiceBuilder::new()
 753                .layer(Extension(server.app_state.clone()))
 754                .layer(middleware::from_fn(auth::validate_header)),
 755        )
 756        .route("/metrics", get(handle_metrics))
 757        .layer(Extension(server))
 758}
 759
 760pub async fn handle_websocket_request(
 761    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 762    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 763    Extension(server): Extension<Arc<Server>>,
 764    Extension(user): Extension<User>,
 765    ws: WebSocketUpgrade,
 766) -> axum::response::Response {
 767    if protocol_version != rpc::PROTOCOL_VERSION {
 768        return (
 769            StatusCode::UPGRADE_REQUIRED,
 770            "client must be upgraded".to_string(),
 771        )
 772            .into_response();
 773    }
 774    let socket_address = socket_address.to_string();
 775    ws.on_upgrade(move |socket| {
 776        use util::ResultExt;
 777        let socket = socket
 778            .map_ok(to_tungstenite_message)
 779            .err_into()
 780            .with(|message| async move { Ok(to_axum_message(message)) });
 781        let connection = Connection::new(Box::pin(socket));
 782        async move {
 783            server
 784                .handle_connection(connection, socket_address, user, None, Executor::Production)
 785                .await
 786                .log_err();
 787        }
 788    })
 789}
 790
 791pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 792    let connections = server
 793        .connection_pool
 794        .lock()
 795        .connections()
 796        .filter(|connection| !connection.admin)
 797        .count();
 798
 799    METRIC_CONNECTIONS.set(connections as _);
 800
 801    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 802    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 803
 804    let encoder = prometheus::TextEncoder::new();
 805    let metric_families = prometheus::gather();
 806    let encoded_metrics = encoder
 807        .encode_to_string(&metric_families)
 808        .map_err(|err| anyhow!("{}", err))?;
 809    Ok(encoded_metrics)
 810}
 811
/// Cleans up after a client's connection drops.
///
/// Immediately disconnects the peer, removes the connection from the pool,
/// and reports the loss to the database. It then waits for the first of:
/// - `RECONNECT_TIMEOUT` elapsing on `executor`. `select_biased!` checks its
///   branches in order, so this branch is polled first. On timeout the session
///   leaves its room. If the pool no longer reports the user as online,
///   `decline_call` is run with no room id. Finally the user's contacts are
///   updated.
/// - `teardown.changed()` resolving, in which case nothing more is done.
///
/// # Errors
/// Failures from `remove_connection` and `update_user_contacts` are returned.
/// Failures from the database `connection_lost`, `leave_room_for_session` and
/// `decline_call` go through `trace_err` and are not returned.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Presumably gives the client a window to reconnect before room state is torn down.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline calls when no other connection for this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown requested: skip the post-timeout cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 852
 853async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 854    response.send(proto::Ack {})?;
 855    Ok(())
 856}
 857
 858async fn create_room(
 859    _request: proto::CreateRoom,
 860    response: Response<proto::CreateRoom>,
 861    session: Session,
 862) -> Result<()> {
 863    let live_kit_room = nanoid::nanoid!(30);
 864    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 865        if let Some(_) = live_kit
 866            .create_room(live_kit_room.clone())
 867            .await
 868            .trace_err()
 869        {
 870            if let Some(token) = live_kit
 871                .room_token(&live_kit_room, &session.user_id.to_string())
 872                .trace_err()
 873            {
 874                Some(proto::LiveKitConnectionInfo {
 875                    server_url: live_kit.url().into(),
 876                    token,
 877                })
 878            } else {
 879                None
 880            }
 881        } else {
 882            None
 883        }
 884    } else {
 885        None
 886    };
 887
 888    {
 889        let room = session
 890            .db()
 891            .await
 892            .create_room(session.user_id, session.connection_id, &live_kit_room)
 893            .await?;
 894
 895        response.send(proto::CreateRoomResponse {
 896            room: Some(room.clone()),
 897            live_kit_connection_info,
 898        })?;
 899    }
 900
 901    update_user_contacts(session.user_id, &session).await?;
 902    Ok(())
 903}
 904
 905async fn join_room(
 906    request: proto::JoinRoom,
 907    response: Response<proto::JoinRoom>,
 908    session: Session,
 909) -> Result<()> {
 910    let room_id = RoomId::from_proto(request.id);
 911    let room = {
 912        let room = session
 913            .db()
 914            .await
 915            .join_room(room_id, session.user_id, session.connection_id)
 916            .await?;
 917        room_updated(&room, &session.peer);
 918        room.clone()
 919    };
 920
 921    for connection_id in session
 922        .connection_pool()
 923        .await
 924        .user_connection_ids(session.user_id)
 925    {
 926        session
 927            .peer
 928            .send(
 929                connection_id,
 930                proto::CallCanceled {
 931                    room_id: room_id.to_proto(),
 932                },
 933            )
 934            .trace_err();
 935    }
 936
 937    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 938        if let Some(token) = live_kit
 939            .room_token(&room.live_kit_room, &session.user_id.to_string())
 940            .trace_err()
 941        {
 942            Some(proto::LiveKitConnectionInfo {
 943                server_url: live_kit.url().into(),
 944                token,
 945            })
 946        } else {
 947            None
 948        }
 949    } else {
 950        None
 951    };
 952
 953    response.send(proto::JoinRoomResponse {
 954        room: Some(room),
 955        live_kit_connection_info,
 956    })?;
 957
 958    update_user_contacts(session.user_id, &session).await?;
 959    Ok(())
 960}
 961
 962async fn rejoin_room(
 963    request: proto::RejoinRoom,
 964    response: Response<proto::RejoinRoom>,
 965    session: Session,
 966) -> Result<()> {
 967    {
 968        let mut rejoined_room = session
 969            .db()
 970            .await
 971            .rejoin_room(request, session.user_id, session.connection_id)
 972            .await?;
 973
 974        response.send(proto::RejoinRoomResponse {
 975            room: Some(rejoined_room.room.clone()),
 976            reshared_projects: rejoined_room
 977                .reshared_projects
 978                .iter()
 979                .map(|project| proto::ResharedProject {
 980                    id: project.id.to_proto(),
 981                    collaborators: project
 982                        .collaborators
 983                        .iter()
 984                        .map(|collaborator| collaborator.to_proto())
 985                        .collect(),
 986                })
 987                .collect(),
 988            rejoined_projects: rejoined_room
 989                .rejoined_projects
 990                .iter()
 991                .map(|rejoined_project| proto::RejoinedProject {
 992                    id: rejoined_project.id.to_proto(),
 993                    worktrees: rejoined_project
 994                        .worktrees
 995                        .iter()
 996                        .map(|worktree| proto::WorktreeMetadata {
 997                            id: worktree.id,
 998                            root_name: worktree.root_name.clone(),
 999                            visible: worktree.visible,
1000                            abs_path: worktree.abs_path.clone(),
1001                        })
1002                        .collect(),
1003                    collaborators: rejoined_project
1004                        .collaborators
1005                        .iter()
1006                        .map(|collaborator| collaborator.to_proto())
1007                        .collect(),
1008                    language_servers: rejoined_project.language_servers.clone(),
1009                })
1010                .collect(),
1011        })?;
1012        room_updated(&rejoined_room.room, &session.peer);
1013
1014        for project in &rejoined_room.reshared_projects {
1015            for collaborator in &project.collaborators {
1016                session
1017                    .peer
1018                    .send(
1019                        collaborator.connection_id,
1020                        proto::UpdateProjectCollaborator {
1021                            project_id: project.id.to_proto(),
1022                            old_peer_id: Some(project.old_connection_id.into()),
1023                            new_peer_id: Some(session.connection_id.into()),
1024                        },
1025                    )
1026                    .trace_err();
1027            }
1028
1029            broadcast(
1030                Some(session.connection_id),
1031                project
1032                    .collaborators
1033                    .iter()
1034                    .map(|collaborator| collaborator.connection_id),
1035                |connection_id| {
1036                    session.peer.forward_send(
1037                        session.connection_id,
1038                        connection_id,
1039                        proto::UpdateProject {
1040                            project_id: project.id.to_proto(),
1041                            worktrees: project.worktrees.clone(),
1042                        },
1043                    )
1044                },
1045            );
1046        }
1047
1048        for project in &rejoined_room.rejoined_projects {
1049            for collaborator in &project.collaborators {
1050                session
1051                    .peer
1052                    .send(
1053                        collaborator.connection_id,
1054                        proto::UpdateProjectCollaborator {
1055                            project_id: project.id.to_proto(),
1056                            old_peer_id: Some(project.old_connection_id.into()),
1057                            new_peer_id: Some(session.connection_id.into()),
1058                        },
1059                    )
1060                    .trace_err();
1061            }
1062        }
1063
1064        for project in &mut rejoined_room.rejoined_projects {
1065            for worktree in mem::take(&mut project.worktrees) {
1066                #[cfg(any(test, feature = "test-support"))]
1067                const MAX_CHUNK_SIZE: usize = 2;
1068                #[cfg(not(any(test, feature = "test-support")))]
1069                const MAX_CHUNK_SIZE: usize = 256;
1070
1071                // Stream this worktree's entries.
1072                let message = proto::UpdateWorktree {
1073                    project_id: project.id.to_proto(),
1074                    worktree_id: worktree.id,
1075                    abs_path: worktree.abs_path.clone(),
1076                    root_name: worktree.root_name,
1077                    updated_entries: worktree.updated_entries,
1078                    removed_entries: worktree.removed_entries,
1079                    scan_id: worktree.scan_id,
1080                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1081                    updated_repositories: worktree.updated_repositories,
1082                    removed_repositories: worktree.removed_repositories,
1083                };
1084                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1085                    session.peer.send(session.connection_id, update.clone())?;
1086                }
1087
1088                // Stream this worktree's diagnostics.
1089                for summary in worktree.diagnostic_summaries {
1090                    session.peer.send(
1091                        session.connection_id,
1092                        proto::UpdateDiagnosticSummary {
1093                            project_id: project.id.to_proto(),
1094                            worktree_id: worktree.id,
1095                            summary: Some(summary),
1096                        },
1097                    )?;
1098                }
1099
1100                for settings_file in worktree.settings_files {
1101                    session.peer.send(
1102                        session.connection_id,
1103                        proto::UpdateWorktreeSettings {
1104                            project_id: project.id.to_proto(),
1105                            worktree_id: worktree.id,
1106                            path: settings_file.path,
1107                            content: Some(settings_file.content),
1108                        },
1109                    )?;
1110                }
1111            }
1112
1113            for language_server in &project.language_servers {
1114                session.peer.send(
1115                    session.connection_id,
1116                    proto::UpdateLanguageServer {
1117                        project_id: project.id.to_proto(),
1118                        language_server_id: language_server.id,
1119                        variant: Some(
1120                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1121                                proto::LspDiskBasedDiagnosticsUpdated {},
1122                            ),
1123                        ),
1124                    },
1125                )?;
1126            }
1127        }
1128    }
1129
1130    update_user_contacts(session.user_id, &session).await?;
1131    Ok(())
1132}
1133
1134async fn leave_room(
1135    _: proto::LeaveRoom,
1136    response: Response<proto::LeaveRoom>,
1137    session: Session,
1138) -> Result<()> {
1139    leave_room_for_session(&session).await?;
1140    response.send(proto::Ack {})?;
1141    Ok(())
1142}
1143
1144async fn call(
1145    request: proto::Call,
1146    response: Response<proto::Call>,
1147    session: Session,
1148) -> Result<()> {
1149    let room_id = RoomId::from_proto(request.room_id);
1150    let calling_user_id = session.user_id;
1151    let calling_connection_id = session.connection_id;
1152    let called_user_id = UserId::from_proto(request.called_user_id);
1153    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1154    if !session
1155        .db()
1156        .await
1157        .has_contact(calling_user_id, called_user_id)
1158        .await?
1159    {
1160        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1161    }
1162
1163    let incoming_call = {
1164        let (room, incoming_call) = &mut *session
1165            .db()
1166            .await
1167            .call(
1168                room_id,
1169                calling_user_id,
1170                calling_connection_id,
1171                called_user_id,
1172                initial_project_id,
1173            )
1174            .await?;
1175        room_updated(&room, &session.peer);
1176        mem::take(incoming_call)
1177    };
1178    update_user_contacts(called_user_id, &session).await?;
1179
1180    let mut calls = session
1181        .connection_pool()
1182        .await
1183        .user_connection_ids(called_user_id)
1184        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1185        .collect::<FuturesUnordered<_>>();
1186
1187    while let Some(call_response) = calls.next().await {
1188        match call_response.as_ref() {
1189            Ok(_) => {
1190                response.send(proto::Ack {})?;
1191                return Ok(());
1192            }
1193            Err(_) => {
1194                call_response.trace_err();
1195            }
1196        }
1197    }
1198
1199    {
1200        let room = session
1201            .db()
1202            .await
1203            .call_failed(room_id, called_user_id)
1204            .await?;
1205        room_updated(&room, &session.peer);
1206    }
1207    update_user_contacts(called_user_id, &session).await?;
1208
1209    Err(anyhow!("failed to ring user"))?
1210}
1211
1212async fn cancel_call(
1213    request: proto::CancelCall,
1214    response: Response<proto::CancelCall>,
1215    session: Session,
1216) -> Result<()> {
1217    let called_user_id = UserId::from_proto(request.called_user_id);
1218    let room_id = RoomId::from_proto(request.room_id);
1219    {
1220        let room = session
1221            .db()
1222            .await
1223            .cancel_call(room_id, session.connection_id, called_user_id)
1224            .await?;
1225        room_updated(&room, &session.peer);
1226    }
1227
1228    for connection_id in session
1229        .connection_pool()
1230        .await
1231        .user_connection_ids(called_user_id)
1232    {
1233        session
1234            .peer
1235            .send(
1236                connection_id,
1237                proto::CallCanceled {
1238                    room_id: room_id.to_proto(),
1239                },
1240            )
1241            .trace_err();
1242    }
1243    response.send(proto::Ack {})?;
1244
1245    update_user_contacts(called_user_id, &session).await?;
1246    Ok(())
1247}
1248
1249async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1250    let room_id = RoomId::from_proto(message.room_id);
1251    {
1252        let room = session
1253            .db()
1254            .await
1255            .decline_call(Some(room_id), session.user_id)
1256            .await?
1257            .ok_or_else(|| anyhow!("failed to decline call"))?;
1258        room_updated(&room, &session.peer);
1259    }
1260
1261    for connection_id in session
1262        .connection_pool()
1263        .await
1264        .user_connection_ids(session.user_id)
1265    {
1266        session
1267            .peer
1268            .send(
1269                connection_id,
1270                proto::CallCanceled {
1271                    room_id: room_id.to_proto(),
1272                },
1273            )
1274            .trace_err();
1275    }
1276    update_user_contacts(session.user_id, &session).await?;
1277    Ok(())
1278}
1279
1280async fn update_participant_location(
1281    request: proto::UpdateParticipantLocation,
1282    response: Response<proto::UpdateParticipantLocation>,
1283    session: Session,
1284) -> Result<()> {
1285    let room_id = RoomId::from_proto(request.room_id);
1286    let location = request
1287        .location
1288        .ok_or_else(|| anyhow!("invalid location"))?;
1289    let room = session
1290        .db()
1291        .await
1292        .update_room_participant_location(room_id, session.connection_id, location)
1293        .await?;
1294    room_updated(&room, &session.peer);
1295    response.send(proto::Ack {})?;
1296    Ok(())
1297}
1298
1299async fn share_project(
1300    request: proto::ShareProject,
1301    response: Response<proto::ShareProject>,
1302    session: Session,
1303) -> Result<()> {
1304    let (project_id, room) = &*session
1305        .db()
1306        .await
1307        .share_project(
1308            RoomId::from_proto(request.room_id),
1309            session.connection_id,
1310            &request.worktrees,
1311        )
1312        .await?;
1313    response.send(proto::ShareProjectResponse {
1314        project_id: project_id.to_proto(),
1315    })?;
1316    room_updated(&room, &session.peer);
1317
1318    Ok(())
1319}
1320
1321async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1322    let project_id = ProjectId::from_proto(message.project_id);
1323
1324    let (room, guest_connection_ids) = &*session
1325        .db()
1326        .await
1327        .unshare_project(project_id, session.connection_id)
1328        .await?;
1329
1330    broadcast(
1331        Some(session.connection_id),
1332        guest_connection_ids.iter().copied(),
1333        |conn_id| session.peer.send(conn_id, message.clone()),
1334    );
1335    room_updated(&room, &session.peer);
1336
1337    Ok(())
1338}
1339
1340async fn join_project(
1341    request: proto::JoinProject,
1342    response: Response<proto::JoinProject>,
1343    session: Session,
1344) -> Result<()> {
1345    let project_id = ProjectId::from_proto(request.project_id);
1346    let guest_user_id = session.user_id;
1347
1348    tracing::info!(%project_id, "join project");
1349
1350    let (project, replica_id) = &mut *session
1351        .db()
1352        .await
1353        .join_project(project_id, session.connection_id)
1354        .await?;
1355
1356    let collaborators = project
1357        .collaborators
1358        .iter()
1359        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1360        .map(|collaborator| collaborator.to_proto())
1361        .collect::<Vec<_>>();
1362
1363    let worktrees = project
1364        .worktrees
1365        .iter()
1366        .map(|(id, worktree)| proto::WorktreeMetadata {
1367            id: *id,
1368            root_name: worktree.root_name.clone(),
1369            visible: worktree.visible,
1370            abs_path: worktree.abs_path.clone(),
1371        })
1372        .collect::<Vec<_>>();
1373
1374    for collaborator in &collaborators {
1375        session
1376            .peer
1377            .send(
1378                collaborator.peer_id.unwrap().into(),
1379                proto::AddProjectCollaborator {
1380                    project_id: project_id.to_proto(),
1381                    collaborator: Some(proto::Collaborator {
1382                        peer_id: Some(session.connection_id.into()),
1383                        replica_id: replica_id.0 as u32,
1384                        user_id: guest_user_id.to_proto(),
1385                    }),
1386                },
1387            )
1388            .trace_err();
1389    }
1390
1391    // First, we send the metadata associated with each worktree.
1392    response.send(proto::JoinProjectResponse {
1393        worktrees: worktrees.clone(),
1394        replica_id: replica_id.0 as u32,
1395        collaborators: collaborators.clone(),
1396        language_servers: project.language_servers.clone(),
1397    })?;
1398
1399    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1400        #[cfg(any(test, feature = "test-support"))]
1401        const MAX_CHUNK_SIZE: usize = 2;
1402        #[cfg(not(any(test, feature = "test-support")))]
1403        const MAX_CHUNK_SIZE: usize = 256;
1404
1405        // Stream this worktree's entries.
1406        let message = proto::UpdateWorktree {
1407            project_id: project_id.to_proto(),
1408            worktree_id,
1409            abs_path: worktree.abs_path.clone(),
1410            root_name: worktree.root_name,
1411            updated_entries: worktree.entries,
1412            removed_entries: Default::default(),
1413            scan_id: worktree.scan_id,
1414            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1415            updated_repositories: worktree.repository_entries.into_values().collect(),
1416            removed_repositories: Default::default(),
1417        };
1418        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1419            session.peer.send(session.connection_id, update.clone())?;
1420        }
1421
1422        // Stream this worktree's diagnostics.
1423        for summary in worktree.diagnostic_summaries {
1424            session.peer.send(
1425                session.connection_id,
1426                proto::UpdateDiagnosticSummary {
1427                    project_id: project_id.to_proto(),
1428                    worktree_id: worktree.id,
1429                    summary: Some(summary),
1430                },
1431            )?;
1432        }
1433
1434        for settings_file in worktree.settings_files {
1435            session.peer.send(
1436                session.connection_id,
1437                proto::UpdateWorktreeSettings {
1438                    project_id: project_id.to_proto(),
1439                    worktree_id: worktree.id,
1440                    path: settings_file.path,
1441                    content: Some(settings_file.content),
1442                },
1443            )?;
1444        }
1445    }
1446
1447    for language_server in &project.language_servers {
1448        session.peer.send(
1449            session.connection_id,
1450            proto::UpdateLanguageServer {
1451                project_id: project_id.to_proto(),
1452                language_server_id: language_server.id,
1453                variant: Some(
1454                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1455                        proto::LspDiskBasedDiagnosticsUpdated {},
1456                    ),
1457                ),
1458            },
1459        )?;
1460    }
1461
1462    Ok(())
1463}
1464
1465async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1466    let sender_id = session.connection_id;
1467    let project_id = ProjectId::from_proto(request.project_id);
1468
1469    let (room, project) = &*session
1470        .db()
1471        .await
1472        .leave_project(project_id, sender_id)
1473        .await?;
1474    tracing::info!(
1475        %project_id,
1476        host_user_id = %project.host_user_id,
1477        host_connection_id = %project.host_connection_id,
1478        "leave project"
1479    );
1480
1481    project_left(&project, &session);
1482    room_updated(&room, &session.peer);
1483
1484    Ok(())
1485}
1486
1487async fn update_project(
1488    request: proto::UpdateProject,
1489    response: Response<proto::UpdateProject>,
1490    session: Session,
1491) -> Result<()> {
1492    let project_id = ProjectId::from_proto(request.project_id);
1493    let (room, guest_connection_ids) = &*session
1494        .db()
1495        .await
1496        .update_project(project_id, session.connection_id, &request.worktrees)
1497        .await?;
1498    broadcast(
1499        Some(session.connection_id),
1500        guest_connection_ids.iter().copied(),
1501        |connection_id| {
1502            session
1503                .peer
1504                .forward_send(session.connection_id, connection_id, request.clone())
1505        },
1506    );
1507    room_updated(&room, &session.peer);
1508    response.send(proto::Ack {})?;
1509
1510    Ok(())
1511}
1512
1513async fn update_worktree(
1514    request: proto::UpdateWorktree,
1515    response: Response<proto::UpdateWorktree>,
1516    session: Session,
1517) -> Result<()> {
1518    let guest_connection_ids = session
1519        .db()
1520        .await
1521        .update_worktree(&request, session.connection_id)
1522        .await?;
1523
1524    broadcast(
1525        Some(session.connection_id),
1526        guest_connection_ids.iter().copied(),
1527        |connection_id| {
1528            session
1529                .peer
1530                .forward_send(session.connection_id, connection_id, request.clone())
1531        },
1532    );
1533    response.send(proto::Ack {})?;
1534    Ok(())
1535}
1536
1537async fn update_diagnostic_summary(
1538    message: proto::UpdateDiagnosticSummary,
1539    session: Session,
1540) -> Result<()> {
1541    let guest_connection_ids = session
1542        .db()
1543        .await
1544        .update_diagnostic_summary(&message, session.connection_id)
1545        .await?;
1546
1547    broadcast(
1548        Some(session.connection_id),
1549        guest_connection_ids.iter().copied(),
1550        |connection_id| {
1551            session
1552                .peer
1553                .forward_send(session.connection_id, connection_id, message.clone())
1554        },
1555    );
1556
1557    Ok(())
1558}
1559
1560async fn update_worktree_settings(
1561    message: proto::UpdateWorktreeSettings,
1562    session: Session,
1563) -> Result<()> {
1564    let guest_connection_ids = session
1565        .db()
1566        .await
1567        .update_worktree_settings(&message, session.connection_id)
1568        .await?;
1569
1570    broadcast(
1571        Some(session.connection_id),
1572        guest_connection_ids.iter().copied(),
1573        |connection_id| {
1574            session
1575                .peer
1576                .forward_send(session.connection_id, connection_id, message.clone())
1577        },
1578    );
1579
1580    Ok(())
1581}
1582
1583async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1584    broadcast_project_message(request.project_id, request, session).await
1585}
1586
1587async fn start_language_server(
1588    request: proto::StartLanguageServer,
1589    session: Session,
1590) -> Result<()> {
1591    let guest_connection_ids = session
1592        .db()
1593        .await
1594        .start_language_server(&request, session.connection_id)
1595        .await?;
1596
1597    broadcast(
1598        Some(session.connection_id),
1599        guest_connection_ids.iter().copied(),
1600        |connection_id| {
1601            session
1602                .peer
1603                .forward_send(session.connection_id, connection_id, request.clone())
1604        },
1605    );
1606    Ok(())
1607}
1608
1609async fn update_language_server(
1610    request: proto::UpdateLanguageServer,
1611    session: Session,
1612) -> Result<()> {
1613    session.executor.record_backtrace();
1614    let project_id = ProjectId::from_proto(request.project_id);
1615    let project_connection_ids = session
1616        .db()
1617        .await
1618        .project_connection_ids(project_id, session.connection_id)
1619        .await?;
1620    broadcast(
1621        Some(session.connection_id),
1622        project_connection_ids.iter().copied(),
1623        |connection_id| {
1624            session
1625                .peer
1626                .forward_send(session.connection_id, connection_id, request.clone())
1627        },
1628    );
1629    Ok(())
1630}
1631
1632async fn forward_project_request<T>(
1633    request: T,
1634    response: Response<T>,
1635    session: Session,
1636) -> Result<()>
1637where
1638    T: EntityMessage + RequestMessage,
1639{
1640    session.executor.record_backtrace();
1641    let project_id = ProjectId::from_proto(request.remote_entity_id());
1642    let host_connection_id = {
1643        let collaborators = session
1644            .db()
1645            .await
1646            .project_collaborators(project_id, session.connection_id)
1647            .await?;
1648        collaborators
1649            .iter()
1650            .find(|collaborator| collaborator.is_host)
1651            .ok_or_else(|| anyhow!("host not found"))?
1652            .connection_id
1653    };
1654
1655    let payload = session
1656        .peer
1657        .forward_request(session.connection_id, host_connection_id, request)
1658        .await?;
1659
1660    response.send(payload)?;
1661    Ok(())
1662}
1663
1664async fn create_buffer_for_peer(
1665    request: proto::CreateBufferForPeer,
1666    session: Session,
1667) -> Result<()> {
1668    session.executor.record_backtrace();
1669    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1670    session
1671        .peer
1672        .forward_send(session.connection_id, peer_id.into(), request)?;
1673    Ok(())
1674}
1675
1676async fn update_buffer(
1677    request: proto::UpdateBuffer,
1678    response: Response<proto::UpdateBuffer>,
1679    session: Session,
1680) -> Result<()> {
1681    session.executor.record_backtrace();
1682    let project_id = ProjectId::from_proto(request.project_id);
1683    let mut guest_connection_ids;
1684    let mut host_connection_id = None;
1685    {
1686        let collaborators = session
1687            .db()
1688            .await
1689            .project_collaborators(project_id, session.connection_id)
1690            .await?;
1691        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1692        for collaborator in collaborators.iter() {
1693            if collaborator.is_host {
1694                host_connection_id = Some(collaborator.connection_id);
1695            } else {
1696                guest_connection_ids.push(collaborator.connection_id);
1697            }
1698        }
1699    }
1700    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1701
1702    session.executor.record_backtrace();
1703    broadcast(
1704        Some(session.connection_id),
1705        guest_connection_ids,
1706        |connection_id| {
1707            session
1708                .peer
1709                .forward_send(session.connection_id, connection_id, request.clone())
1710        },
1711    );
1712    if host_connection_id != session.connection_id {
1713        session
1714            .peer
1715            .forward_request(session.connection_id, host_connection_id, request.clone())
1716            .await?;
1717    }
1718
1719    response.send(proto::Ack {})?;
1720    Ok(())
1721}
1722
1723async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1724    let project_id = ProjectId::from_proto(request.project_id);
1725    let project_connection_ids = session
1726        .db()
1727        .await
1728        .project_connection_ids(project_id, session.connection_id)
1729        .await?;
1730
1731    broadcast(
1732        Some(session.connection_id),
1733        project_connection_ids.iter().copied(),
1734        |connection_id| {
1735            session
1736                .peer
1737                .forward_send(session.connection_id, connection_id, request.clone())
1738        },
1739    );
1740    Ok(())
1741}
1742
1743async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1744    let project_id = ProjectId::from_proto(request.project_id);
1745    let project_connection_ids = session
1746        .db()
1747        .await
1748        .project_connection_ids(project_id, session.connection_id)
1749        .await?;
1750    broadcast(
1751        Some(session.connection_id),
1752        project_connection_ids.iter().copied(),
1753        |connection_id| {
1754            session
1755                .peer
1756                .forward_send(session.connection_id, connection_id, request.clone())
1757        },
1758    );
1759    Ok(())
1760}
1761
1762async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1763    broadcast_project_message(request.project_id, request, session).await
1764}
1765
1766async fn broadcast_project_message<T: EnvelopedMessage>(
1767    project_id: u64,
1768    request: T,
1769    session: Session,
1770) -> Result<()> {
1771    let project_id = ProjectId::from_proto(project_id);
1772    let project_connection_ids = session
1773        .db()
1774        .await
1775        .project_connection_ids(project_id, session.connection_id)
1776        .await?;
1777    broadcast(
1778        Some(session.connection_id),
1779        project_connection_ids.iter().copied(),
1780        |connection_id| {
1781            session
1782                .peer
1783                .forward_send(session.connection_id, connection_id, request.clone())
1784        },
1785    );
1786    Ok(())
1787}
1788
1789async fn follow(
1790    request: proto::Follow,
1791    response: Response<proto::Follow>,
1792    session: Session,
1793) -> Result<()> {
1794    let project_id = ProjectId::from_proto(request.project_id);
1795    let leader_id = request
1796        .leader_id
1797        .ok_or_else(|| anyhow!("invalid leader id"))?
1798        .into();
1799    let follower_id = session.connection_id;
1800
1801    {
1802        let project_connection_ids = session
1803            .db()
1804            .await
1805            .project_connection_ids(project_id, session.connection_id)
1806            .await?;
1807
1808        if !project_connection_ids.contains(&leader_id) {
1809            Err(anyhow!("no such peer"))?;
1810        }
1811    }
1812
1813    let mut response_payload = session
1814        .peer
1815        .forward_request(session.connection_id, leader_id, request)
1816        .await?;
1817    response_payload
1818        .views
1819        .retain(|view| view.leader_id != Some(follower_id.into()));
1820    response.send(response_payload)?;
1821
1822    let room = session
1823        .db()
1824        .await
1825        .follow(project_id, leader_id, follower_id)
1826        .await?;
1827    room_updated(&room, &session.peer);
1828
1829    Ok(())
1830}
1831
1832async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1833    let project_id = ProjectId::from_proto(request.project_id);
1834    let leader_id = request
1835        .leader_id
1836        .ok_or_else(|| anyhow!("invalid leader id"))?
1837        .into();
1838    let follower_id = session.connection_id;
1839
1840    if !session
1841        .db()
1842        .await
1843        .project_connection_ids(project_id, session.connection_id)
1844        .await?
1845        .contains(&leader_id)
1846    {
1847        Err(anyhow!("no such peer"))?;
1848    }
1849
1850    session
1851        .peer
1852        .forward_send(session.connection_id, leader_id, request)?;
1853
1854    let room = session
1855        .db()
1856        .await
1857        .unfollow(project_id, leader_id, follower_id)
1858        .await?;
1859    room_updated(&room, &session.peer);
1860
1861    Ok(())
1862}
1863
1864async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1865    let project_id = ProjectId::from_proto(request.project_id);
1866    let project_connection_ids = session
1867        .db
1868        .lock()
1869        .await
1870        .project_connection_ids(project_id, session.connection_id)
1871        .await?;
1872
1873    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1874        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1875        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1876        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1877    });
1878    for follower_peer_id in request.follower_ids.iter().copied() {
1879        let follower_connection_id = follower_peer_id.into();
1880        if project_connection_ids.contains(&follower_connection_id)
1881            && Some(follower_peer_id) != leader_id
1882        {
1883            session.peer.forward_send(
1884                session.connection_id,
1885                follower_connection_id,
1886                request.clone(),
1887            )?;
1888        }
1889    }
1890    Ok(())
1891}
1892
1893async fn get_users(
1894    request: proto::GetUsers,
1895    response: Response<proto::GetUsers>,
1896    session: Session,
1897) -> Result<()> {
1898    let user_ids = request
1899        .user_ids
1900        .into_iter()
1901        .map(UserId::from_proto)
1902        .collect();
1903    let users = session
1904        .db()
1905        .await
1906        .get_users_by_ids(user_ids)
1907        .await?
1908        .into_iter()
1909        .map(|user| proto::User {
1910            id: user.id.to_proto(),
1911            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1912            github_login: user.github_login,
1913        })
1914        .collect();
1915    response.send(proto::UsersResponse { users })?;
1916    Ok(())
1917}
1918
1919async fn fuzzy_search_users(
1920    request: proto::FuzzySearchUsers,
1921    response: Response<proto::FuzzySearchUsers>,
1922    session: Session,
1923) -> Result<()> {
1924    let query = request.query;
1925    let users = match query.len() {
1926        0 => vec![],
1927        1 | 2 => session
1928            .db()
1929            .await
1930            .get_user_by_github_login(&query)
1931            .await?
1932            .into_iter()
1933            .collect(),
1934        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1935    };
1936    let users = users
1937        .into_iter()
1938        .filter(|user| user.id != session.user_id)
1939        .map(|user| proto::User {
1940            id: user.id.to_proto(),
1941            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1942            github_login: user.github_login,
1943        })
1944        .collect();
1945    response.send(proto::UsersResponse { users })?;
1946    Ok(())
1947}
1948
1949async fn request_contact(
1950    request: proto::RequestContact,
1951    response: Response<proto::RequestContact>,
1952    session: Session,
1953) -> Result<()> {
1954    let requester_id = session.user_id;
1955    let responder_id = UserId::from_proto(request.responder_id);
1956    if requester_id == responder_id {
1957        return Err(anyhow!("cannot add yourself as a contact"))?;
1958    }
1959
1960    session
1961        .db()
1962        .await
1963        .send_contact_request(requester_id, responder_id)
1964        .await?;
1965
1966    // Update outgoing contact requests of requester
1967    let mut update = proto::UpdateContacts::default();
1968    update.outgoing_requests.push(responder_id.to_proto());
1969    for connection_id in session
1970        .connection_pool()
1971        .await
1972        .user_connection_ids(requester_id)
1973    {
1974        session.peer.send(connection_id, update.clone())?;
1975    }
1976
1977    // Update incoming contact requests of responder
1978    let mut update = proto::UpdateContacts::default();
1979    update
1980        .incoming_requests
1981        .push(proto::IncomingContactRequest {
1982            requester_id: requester_id.to_proto(),
1983            should_notify: true,
1984        });
1985    for connection_id in session
1986        .connection_pool()
1987        .await
1988        .user_connection_ids(responder_id)
1989    {
1990        session.peer.send(connection_id, update.clone())?;
1991    }
1992
1993    response.send(proto::Ack {})?;
1994    Ok(())
1995}
1996
1997async fn respond_to_contact_request(
1998    request: proto::RespondToContactRequest,
1999    response: Response<proto::RespondToContactRequest>,
2000    session: Session,
2001) -> Result<()> {
2002    let responder_id = session.user_id;
2003    let requester_id = UserId::from_proto(request.requester_id);
2004    let db = session.db().await;
2005    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2006        db.dismiss_contact_notification(responder_id, requester_id)
2007            .await?;
2008    } else {
2009        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2010
2011        db.respond_to_contact_request(responder_id, requester_id, accept)
2012            .await?;
2013        let requester_busy = db.is_user_busy(requester_id).await?;
2014        let responder_busy = db.is_user_busy(responder_id).await?;
2015
2016        let pool = session.connection_pool().await;
2017        // Update responder with new contact
2018        let mut update = proto::UpdateContacts::default();
2019        if accept {
2020            update
2021                .contacts
2022                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2023        }
2024        update
2025            .remove_incoming_requests
2026            .push(requester_id.to_proto());
2027        for connection_id in pool.user_connection_ids(responder_id) {
2028            session.peer.send(connection_id, update.clone())?;
2029        }
2030
2031        // Update requester with new contact
2032        let mut update = proto::UpdateContacts::default();
2033        if accept {
2034            update
2035                .contacts
2036                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2037        }
2038        update
2039            .remove_outgoing_requests
2040            .push(responder_id.to_proto());
2041        for connection_id in pool.user_connection_ids(requester_id) {
2042            session.peer.send(connection_id, update.clone())?;
2043        }
2044    }
2045
2046    response.send(proto::Ack {})?;
2047    Ok(())
2048}
2049
2050async fn remove_contact(
2051    request: proto::RemoveContact,
2052    response: Response<proto::RemoveContact>,
2053    session: Session,
2054) -> Result<()> {
2055    let requester_id = session.user_id;
2056    let responder_id = UserId::from_proto(request.user_id);
2057    let db = session.db().await;
2058    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2059
2060    let pool = session.connection_pool().await;
2061    // Update outgoing contact requests of requester
2062    let mut update = proto::UpdateContacts::default();
2063    if contact_accepted {
2064        update.remove_contacts.push(responder_id.to_proto());
2065    } else {
2066        update
2067            .remove_outgoing_requests
2068            .push(responder_id.to_proto());
2069    }
2070    for connection_id in pool.user_connection_ids(requester_id) {
2071        session.peer.send(connection_id, update.clone())?;
2072    }
2073
2074    // Update incoming contact requests of responder
2075    let mut update = proto::UpdateContacts::default();
2076    if contact_accepted {
2077        update.remove_contacts.push(requester_id.to_proto());
2078    } else {
2079        update
2080            .remove_incoming_requests
2081            .push(requester_id.to_proto());
2082    }
2083    for connection_id in pool.user_connection_ids(responder_id) {
2084        session.peer.send(connection_id, update.clone())?;
2085    }
2086
2087    response.send(proto::Ack {})?;
2088    Ok(())
2089}
2090
2091async fn create_channel(
2092    request: proto::CreateChannel,
2093    response: Response<proto::CreateChannel>,
2094    session: Session,
2095) -> Result<()> {
2096    let db = session.db().await;
2097    let id = db
2098        .create_channel(
2099            &request.name,
2100            request.parent_id.map(|id| ChannelId::from_proto(id)),
2101            session.user_id,
2102        )
2103        .await?;
2104
2105    let mut update = proto::UpdateChannels::default();
2106    update.channels.push(proto::Channel {
2107        id: id.to_proto(),
2108        name: request.name,
2109        parent_id: request.parent_id,
2110    });
2111    session.peer.send(session.connection_id, update)?;
2112    response.send(proto::CreateChannelResponse {
2113        channel_id: id.to_proto(),
2114    })?;
2115
2116    Ok(())
2117}
2118
2119async fn invite_channel_member(
2120    request: proto::InviteChannelMember,
2121    response: Response<proto::InviteChannelMember>,
2122    session: Session,
2123) -> Result<()> {
2124    let db = session.db().await;
2125    let channel_id = ChannelId::from_proto(request.channel_id);
2126    let channel = db.get_channel(channel_id).await?;
2127    let invitee_id = UserId::from_proto(request.user_id);
2128    db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2129        .await?;
2130
2131    let mut update = proto::UpdateChannels::default();
2132    update.channel_invitations.push(proto::Channel {
2133        id: channel.id.to_proto(),
2134        name: channel.name,
2135        parent_id: None,
2136    });
2137    for connection_id in session
2138        .connection_pool()
2139        .await
2140        .user_connection_ids(invitee_id)
2141    {
2142        session.peer.send(connection_id, update.clone())?;
2143    }
2144
2145    response.send(proto::Ack {})?;
2146    Ok(())
2147}
2148
2149async fn remove_channel_member(
2150    request: proto::RemoveChannelMember,
2151    response: Response<proto::RemoveChannelMember>,
2152    session: Session,
2153) -> Result<()> {
2154    Ok(())
2155}
2156
2157async fn respond_to_channel_invite(
2158    request: proto::RespondToChannelInvite,
2159    response: Response<proto::RespondToChannelInvite>,
2160    session: Session,
2161) -> Result<()> {
2162    let db = session.db().await;
2163    let channel_id = ChannelId::from_proto(request.channel_id);
2164    let channel = db.get_channel(channel_id).await?;
2165    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2166        .await?;
2167
2168    let mut update = proto::UpdateChannels::default();
2169    update
2170        .remove_channel_invitations
2171        .push(channel_id.to_proto());
2172    if request.accept {
2173        update.channels.push(proto::Channel {
2174            id: channel.id.to_proto(),
2175            name: channel.name,
2176            parent_id: None,
2177        });
2178    }
2179    session.peer.send(session.connection_id, update)?;
2180    response.send(proto::Ack {})?;
2181
2182    Ok(())
2183}
2184
2185async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2186    let project_id = ProjectId::from_proto(request.project_id);
2187    let project_connection_ids = session
2188        .db()
2189        .await
2190        .project_connection_ids(project_id, session.connection_id)
2191        .await?;
2192    broadcast(
2193        Some(session.connection_id),
2194        project_connection_ids.iter().copied(),
2195        |connection_id| {
2196            session
2197                .peer
2198                .forward_send(session.connection_id, connection_id, request.clone())
2199        },
2200    );
2201    Ok(())
2202}
2203
2204async fn get_private_user_info(
2205    _request: proto::GetPrivateUserInfo,
2206    response: Response<proto::GetPrivateUserInfo>,
2207    session: Session,
2208) -> Result<()> {
2209    let metrics_id = session
2210        .db()
2211        .await
2212        .get_user_metrics_id(session.user_id)
2213        .await?;
2214    let user = session
2215        .db()
2216        .await
2217        .get_user_by_id(session.user_id)
2218        .await?
2219        .ok_or_else(|| anyhow!("user not found"))?;
2220    response.send(proto::GetPrivateUserInfoResponse {
2221        metrics_id,
2222        staff: user.admin,
2223    })?;
2224    Ok(())
2225}
2226
2227fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2228    match message {
2229        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2230        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2231        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2232        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2233        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2234            code: frame.code.into(),
2235            reason: frame.reason,
2236        })),
2237    }
2238}
2239
2240fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2241    match message {
2242        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2243        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2244        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2245        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2246        AxumMessage::Close(frame) => {
2247            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2248                code: frame.code.into(),
2249                reason: frame.reason,
2250            }))
2251        }
2252    }
2253}
2254
2255fn build_initial_contacts_update(
2256    contacts: Vec<db::Contact>,
2257    pool: &ConnectionPool,
2258) -> proto::UpdateContacts {
2259    let mut update = proto::UpdateContacts::default();
2260
2261    for contact in contacts {
2262        match contact {
2263            db::Contact::Accepted {
2264                user_id,
2265                should_notify,
2266                busy,
2267            } => {
2268                update
2269                    .contacts
2270                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2271            }
2272            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2273            db::Contact::Incoming {
2274                user_id,
2275                should_notify,
2276            } => update
2277                .incoming_requests
2278                .push(proto::IncomingContactRequest {
2279                    requester_id: user_id.to_proto(),
2280                    should_notify,
2281                }),
2282        }
2283    }
2284
2285    update
2286}
2287
2288fn contact_for_user(
2289    user_id: UserId,
2290    should_notify: bool,
2291    busy: bool,
2292    pool: &ConnectionPool,
2293) -> proto::Contact {
2294    proto::Contact {
2295        user_id: user_id.to_proto(),
2296        online: pool.is_user_online(user_id),
2297        busy,
2298        should_notify,
2299    }
2300}
2301
2302fn room_updated(room: &proto::Room, peer: &Peer) {
2303    broadcast(
2304        None,
2305        room.participants
2306            .iter()
2307            .filter_map(|participant| Some(participant.peer_id?.into())),
2308        |peer_id| {
2309            peer.send(
2310                peer_id.into(),
2311                proto::RoomUpdated {
2312                    room: Some(room.clone()),
2313                },
2314            )
2315        },
2316    );
2317}
2318
2319async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2320    let db = session.db().await;
2321    let contacts = db.get_contacts(user_id).await?;
2322    let busy = db.is_user_busy(user_id).await?;
2323
2324    let pool = session.connection_pool().await;
2325    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2326    for contact in contacts {
2327        if let db::Contact::Accepted {
2328            user_id: contact_user_id,
2329            ..
2330        } = contact
2331        {
2332            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2333                session
2334                    .peer
2335                    .send(
2336                        contact_conn_id,
2337                        proto::UpdateContacts {
2338                            contacts: vec![updated_contact.clone()],
2339                            remove_contacts: Default::default(),
2340                            incoming_requests: Default::default(),
2341                            remove_incoming_requests: Default::default(),
2342                            outgoing_requests: Default::default(),
2343                            remove_outgoing_requests: Default::default(),
2344                        },
2345                    )
2346                    .trace_err();
2347            }
2348        }
2349    }
2350    Ok(())
2351}
2352
/// Removes this session's connection from the room it is in, if any, and
/// notifies everyone affected.
///
/// If the connection was in a room:
/// - each project it left is handled by `project_left`;
/// - the updated room is broadcast to its participants;
/// - users whose calls were canceled receive `CallCanceled`;
/// - contact info is refreshed for the leaving user and the canceled users;
/// - the LiveKit participant is removed, and the LiveKit room is deleted when
///   no participants remain.
///
/// Returns `Ok(())` immediately when the connection was not in a room.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries (e.g. busy status) must be re-broadcast.
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below; the `else` branch returns early, so
    // they are always initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // NOTE(review): the `session.db().await` guard is a temporary in the
    // `if let` scrutinee. Under pre-2024-edition drop rules it stays alive
    // until the end of this `if let`/`else`, so the synchronous notifications
    // inside run while it is held — confirm this is intended.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        // Move these values out of `left_room` instead of cloning them.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.room.participants.is_empty();
    } else {
        return Ok(());
    }

    {
        // Scoped so the connection-pool guard is released before the awaits
        // in `update_user_contacts` below.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // An error here is propagated with `?`. It skips the remaining contact
    // updates and the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are only logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2411
2412fn project_left(project: &db::LeftProject, session: &Session) {
2413    for connection_id in &project.connection_ids {
2414        if project.host_user_id == session.user_id {
2415            session
2416                .peer
2417                .send(
2418                    *connection_id,
2419                    proto::UnshareProject {
2420                        project_id: project.id.to_proto(),
2421                    },
2422                )
2423                .trace_err();
2424        } else {
2425            session
2426                .peer
2427                .send(
2428                    *connection_id,
2429                    proto::RemoveProjectCollaborator {
2430                        project_id: project.id.to_proto(),
2431                        peer_id: Some(session.connection_id.into()),
2432                    },
2433                )
2434                .trace_err();
2435        }
2436    }
2437}
2438
2439pub trait ResultExt {
2440    type Ok;
2441
2442    fn trace_err(self) -> Option<Self::Ok>;
2443}
2444
2445impl<T, E> ResultExt for Result<T, E>
2446where
2447    E: std::fmt::Debug,
2448{
2449    type Ok = T;
2450
2451    fn trace_err(self) -> Option<T> {
2452        match self {
2453            Ok(value) => Some(value),
2454            Err(error) => {
2455                tracing::error!("{:?}", error);
2456                None
2457            }
2458        }
2459    }
2460}