rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Config, Error, RateLimit, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use http_client::HttpClient;
  37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
  38use reqwest_client::ReqwestClient;
  39use rpc::proto::split_repository_update;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{MutexGuard, Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
 128#[derive(Clone)]
 129struct Session {
 130    principal: Principal,
 131    connection_id: ConnectionId,
 132    db: Arc<tokio::sync::Mutex<DbHandle>>,
 133    peer: Arc<Peer>,
 134    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 135    app_state: Arc<AppState>,
 136    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 137    http_client: Arc<dyn HttpClient>,
 138    /// The GeoIP country code for the user.
 139    #[allow(unused)]
 140    geoip_country_code: Option<String>,
 141    system_id: Option<String>,
 142    _executor: Executor,
 143}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
 237pub struct Server {
 238    id: parking_lot::Mutex<ServerId>,
 239    peer: Arc<Peer>,
 240    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 241    app_state: Arc<AppState>,
 242    handlers: HashMap<TypeId, MessageHandler>,
 243    teardown: watch::Sender<bool>,
 244}
 245
 246pub(crate) struct ConnectionPoolGuard<'a> {
 247    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 248    _not_send: PhantomData<Rc<()>>,
 249}
 250
 251#[derive(Serialize)]
 252pub struct ServerSnapshot<'a> {
 253    peer: &'a Peer,
 254    #[serde(serialize_with = "serialize_deref")]
 255    connection_pool: ConnectionPoolGuard<'a>,
 256}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
 268    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 269        let mut server = Self {
 270            id: parking_lot::Mutex::new(id),
 271            peer: Peer::new(id.0 as u32),
 272            app_state: app_state.clone(),
 273            connection_pool: Default::default(),
 274            handlers: Default::default(),
 275            teardown: watch::channel(false).0,
 276        };
 277
 278        server
 279            .add_request_handler(ping)
 280            .add_request_handler(create_room)
 281            .add_request_handler(join_room)
 282            .add_request_handler(rejoin_room)
 283            .add_request_handler(leave_room)
 284            .add_request_handler(set_room_participant_role)
 285            .add_request_handler(call)
 286            .add_request_handler(cancel_call)
 287            .add_message_handler(decline_call)
 288            .add_request_handler(update_participant_location)
 289            .add_request_handler(share_project)
 290            .add_message_handler(unshare_project)
 291            .add_request_handler(join_project)
 292            .add_message_handler(leave_project)
 293            .add_request_handler(update_project)
 294            .add_request_handler(update_worktree)
 295            .add_request_handler(update_repository)
 296            .add_request_handler(remove_repository)
 297            .add_message_handler(start_language_server)
 298            .add_message_handler(update_language_server)
 299            .add_message_handler(update_diagnostic_summary)
 300            .add_message_handler(update_worktree_settings)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 305            .add_request_handler(forward_find_search_candidates_request)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 309            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 311            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 312            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 313            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 314            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 320            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(
 325                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 326            )
 327            .add_request_handler(
 328                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 329            )
 330            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 331            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 332            .add_request_handler(
 333                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 334            )
 335            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 336            .add_request_handler(
 337                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 338            )
 339            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 358            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 359            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 360            .add_message_handler(create_buffer_for_peer)
 361            .add_request_handler(update_buffer)
 362            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 363            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 368            .add_request_handler(get_users)
 369            .add_request_handler(fuzzy_search_users)
 370            .add_request_handler(request_contact)
 371            .add_request_handler(remove_contact)
 372            .add_request_handler(respond_to_contact_request)
 373            .add_message_handler(subscribe_to_channels)
 374            .add_request_handler(create_channel)
 375            .add_request_handler(delete_channel)
 376            .add_request_handler(invite_channel_member)
 377            .add_request_handler(remove_channel_member)
 378            .add_request_handler(set_channel_member_role)
 379            .add_request_handler(set_channel_visibility)
 380            .add_request_handler(rename_channel)
 381            .add_request_handler(join_channel_buffer)
 382            .add_request_handler(leave_channel_buffer)
 383            .add_message_handler(update_channel_buffer)
 384            .add_request_handler(rejoin_channel_buffers)
 385            .add_request_handler(get_channel_members)
 386            .add_request_handler(respond_to_channel_invite)
 387            .add_request_handler(join_channel)
 388            .add_request_handler(join_channel_chat)
 389            .add_message_handler(leave_channel_chat)
 390            .add_request_handler(send_channel_message)
 391            .add_request_handler(remove_channel_message)
 392            .add_request_handler(update_channel_message)
 393            .add_request_handler(get_channel_messages)
 394            .add_request_handler(get_channel_messages_by_id)
 395            .add_request_handler(get_notifications)
 396            .add_request_handler(mark_notification_as_read)
 397            .add_request_handler(move_channel)
 398            .add_request_handler(follow)
 399            .add_message_handler(unfollow)
 400            .add_message_handler(update_followers)
 401            .add_request_handler(get_private_user_info)
 402            .add_request_handler(get_llm_api_token)
 403            .add_request_handler(accept_terms_of_service)
 404            .add_message_handler(acknowledge_channel_message)
 405            .add_message_handler(acknowledge_buffer_version)
 406            .add_request_handler(get_supermaven_api_key)
 407            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 409            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 410            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 411            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 414            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 416            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 417            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 418            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 419            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 420            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 421            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 422            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 423            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 424            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 425            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 426            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 428            .add_message_handler(update_context)
 429            .add_request_handler({
 430                let app_state = app_state.clone();
 431                move |request, response, session| {
 432                    let app_state = app_state.clone();
 433                    async move {
 434                        count_language_model_tokens(request, response, session, &app_state.config)
 435                            .await
 436                    }
 437                }
 438            })
 439            .add_request_handler(get_cached_embeddings)
 440            .add_request_handler({
 441                let app_state = app_state.clone();
 442                move |request, response, session| {
 443                    compute_embeddings(
 444                        request,
 445                        response,
 446                        session,
 447                        app_state.config.openai_api_key.clone(),
 448                    )
 449                }
 450            });
 451
 452        Arc::new(server)
 453    }
 454
 455    pub async fn start(&self) -> Result<()> {
 456        let server_id = *self.id.lock();
 457        let app_state = self.app_state.clone();
 458        let peer = self.peer.clone();
 459        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 460        let pool = self.connection_pool.clone();
 461        let livekit_client = self.app_state.livekit_client.clone();
 462
 463        let span = info_span!("start server");
 464        self.app_state.executor.spawn_detached(
 465            async move {
 466                tracing::info!("waiting for cleanup timeout");
 467                timeout.await;
 468                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 469                if let Some((room_ids, channel_ids)) = app_state
 470                    .db
 471                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 472                    .await
 473                    .trace_err()
 474                {
 475                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 476                    tracing::info!(
 477                        stale_channel_buffer_count = channel_ids.len(),
 478                        "retrieved stale channel buffers"
 479                    );
 480
 481                    for channel_id in channel_ids {
 482                        if let Some(refreshed_channel_buffer) = app_state
 483                            .db
 484                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 485                            .await
 486                            .trace_err()
 487                        {
 488                            for connection_id in refreshed_channel_buffer.connection_ids {
 489                                peer.send(
 490                                    connection_id,
 491                                    proto::UpdateChannelBufferCollaborators {
 492                                        channel_id: channel_id.to_proto(),
 493                                        collaborators: refreshed_channel_buffer
 494                                            .collaborators
 495                                            .clone(),
 496                                    },
 497                                )
 498                                .trace_err();
 499                            }
 500                        }
 501                    }
 502
 503                    for room_id in room_ids {
 504                        let mut contacts_to_update = HashSet::default();
 505                        let mut canceled_calls_to_user_ids = Vec::new();
 506                        let mut livekit_room = String::new();
 507                        let mut delete_livekit_room = false;
 508
 509                        if let Some(mut refreshed_room) = app_state
 510                            .db
 511                            .clear_stale_room_participants(room_id, server_id)
 512                            .await
 513                            .trace_err()
 514                        {
 515                            tracing::info!(
 516                                room_id = room_id.0,
 517                                new_participant_count = refreshed_room.room.participants.len(),
 518                                "refreshed room"
 519                            );
 520                            room_updated(&refreshed_room.room, &peer);
 521                            if let Some(channel) = refreshed_room.channel.as_ref() {
 522                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 523                            }
 524                            contacts_to_update
 525                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 526                            contacts_to_update
 527                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 528                            canceled_calls_to_user_ids =
 529                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 530                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 531                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 532                        }
 533
 534                        {
 535                            let pool = pool.lock();
 536                            for canceled_user_id in canceled_calls_to_user_ids {
 537                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 538                                    peer.send(
 539                                        connection_id,
 540                                        proto::CallCanceled {
 541                                            room_id: room_id.to_proto(),
 542                                        },
 543                                    )
 544                                    .trace_err();
 545                                }
 546                            }
 547                        }
 548
 549                        for user_id in contacts_to_update {
 550                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 551                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 552                            if let Some((busy, contacts)) = busy.zip(contacts) {
 553                                let pool = pool.lock();
 554                                let updated_contact = contact_for_user(user_id, busy, &pool);
 555                                for contact in contacts {
 556                                    if let db::Contact::Accepted {
 557                                        user_id: contact_user_id,
 558                                        ..
 559                                    } = contact
 560                                    {
 561                                        for contact_conn_id in
 562                                            pool.user_connection_ids(contact_user_id)
 563                                        {
 564                                            peer.send(
 565                                                contact_conn_id,
 566                                                proto::UpdateContacts {
 567                                                    contacts: vec![updated_contact.clone()],
 568                                                    remove_contacts: Default::default(),
 569                                                    incoming_requests: Default::default(),
 570                                                    remove_incoming_requests: Default::default(),
 571                                                    outgoing_requests: Default::default(),
 572                                                    remove_outgoing_requests: Default::default(),
 573                                                },
 574                                            )
 575                                            .trace_err();
 576                                        }
 577                                    }
 578                                }
 579                            }
 580                        }
 581
 582                        if let Some(live_kit) = livekit_client.as_ref() {
 583                            if delete_livekit_room {
 584                                live_kit.delete_room(livekit_room).await.trace_err();
 585                            }
 586                        }
 587                    }
 588                }
 589
 590                app_state
 591                    .db
 592                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 593                    .await
 594                    .trace_err();
 595            }
 596            .instrument(span),
 597        );
 598        Ok(())
 599    }
 600
 601    pub fn teardown(&self) {
 602        self.peer.teardown();
 603        self.connection_pool.lock().reset();
 604        let _ = self.teardown.send(true);
 605    }
 606
 607    #[cfg(test)]
 608    pub fn reset(&self, id: ServerId) {
 609        self.teardown();
 610        *self.id.lock() = id;
 611        self.peer.reset(id.0 as u32);
 612        let _ = self.teardown.send(false);
 613    }
 614
 615    #[cfg(test)]
 616    pub fn id(&self) -> ServerId {
 617        *self.id.lock()
 618    }
 619
 620    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 621    where
 622        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 623        Fut: 'static + Send + Future<Output = Result<()>>,
 624        M: EnvelopedMessage,
 625    {
 626        let prev_handler = self.handlers.insert(
 627            TypeId::of::<M>(),
 628            Box::new(move |envelope, session| {
 629                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 630                let received_at = envelope.received_at;
 631                tracing::info!("message received");
 632                let start_time = Instant::now();
 633                let future = (handler)(*envelope, session);
 634                async move {
 635                    let result = future.await;
 636                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 637                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 638                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 639                    let payload_type = M::NAME;
 640
 641                    match result {
 642                        Err(error) => {
 643                            tracing::error!(
 644                                ?error,
 645                                total_duration_ms,
 646                                processing_duration_ms,
 647                                queue_duration_ms,
 648                                payload_type,
 649                                "error handling message"
 650                            )
 651                        }
 652                        Ok(()) => tracing::info!(
 653                            total_duration_ms,
 654                            processing_duration_ms,
 655                            queue_duration_ms,
 656                            "finished handling message"
 657                        ),
 658                    }
 659                }
 660                .boxed()
 661            }),
 662        );
 663        if prev_handler.is_some() {
 664            panic!("registered a handler for the same message twice");
 665        }
 666        self
 667    }
 668
 669    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 670    where
 671        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 672        Fut: 'static + Send + Future<Output = Result<()>>,
 673        M: EnvelopedMessage,
 674    {
 675        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 676        self
 677    }
 678
 679    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 680    where
 681        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 682        Fut: Send + Future<Output = Result<()>>,
 683        M: RequestMessage,
 684    {
 685        let handler = Arc::new(handler);
 686        self.add_handler(move |envelope, session| {
 687            let receipt = envelope.receipt();
 688            let handler = handler.clone();
 689            async move {
 690                let peer = session.peer.clone();
 691                let responded = Arc::new(AtomicBool::default());
 692                let response = Response {
 693                    peer: peer.clone(),
 694                    responded: responded.clone(),
 695                    receipt,
 696                };
 697                match (handler)(envelope.payload, response, session).await {
 698                    Ok(()) => {
 699                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 700                            Ok(())
 701                        } else {
 702                            Err(anyhow!("handler did not send a response"))?
 703                        }
 704                    }
 705                    Err(error) => {
 706                        let proto_err = match &error {
 707                            Error::Internal(err) => err.to_proto(),
 708                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 709                        };
 710                        peer.respond_with_error(receipt, proto_err)?;
 711                        Err(error)
 712                    }
 713                }
 714            }
 715        })
 716    }
 717
    /// Returns a future that serves one client connection from handshake to sign-out.
    ///
    /// The whole future is instrumented with a `handle connection` span carrying the
    /// address, connection id, principal fields and (when known) the GeoIP country code.
    /// When polled, the future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds the per-connection `Session`,
    ///    returning early if the HTTP client cannot be created;
    /// 3. runs the handshake via `send_initial_client_update`, returning early on failure;
    /// 4. dispatches incoming messages to the registered handlers until the I/O future
    ///    completes or the incoming stream ends, and then runs `connection_lost`.
    ///
    /// If teardown is signalled while the message loop is running, the future returns
    /// without calling `connection_lost`.
    ///
    /// The `use<>` bound means the returned future captures no borrowed lifetimes; it owns
    /// a clone of the `Arc<Self>` (`this`).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Receiver for the server-wide teardown flag: checked on entry and raced against
        // the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` drives the socket; `incoming_rx` yields decoded messages.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only created when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256. A permit is
            // acquired before the next message is read, so the loop stops pulling messages
            // while every permit is held.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, then socket I/O, then
                // completing foreground handlers, then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before scheduling; the future re-enters it
                                // through `instrument(span)` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the concurrency permit only once the handler finishes.
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 866
 867    async fn send_initial_client_update(
 868        &self,
 869        connection_id: ConnectionId,
 870        principal: &Principal,
 871        zed_version: ZedVersion,
 872        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 873        session: &Session,
 874    ) -> Result<()> {
 875        self.peer.send(
 876            connection_id,
 877            proto::Hello {
 878                peer_id: Some(connection_id.into()),
 879            },
 880        )?;
 881        tracing::info!("sent hello message");
 882        if let Some(send_connection_id) = send_connection_id.take() {
 883            let _ = send_connection_id.send(connection_id);
 884        }
 885
 886        match principal {
 887            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 888                if !user.connected_once {
 889                    self.peer.send(connection_id, proto::ShowContacts {})?;
 890                    self.app_state
 891                        .db
 892                        .set_user_connected_once(user.id, true)
 893                        .await?;
 894                }
 895
 896                update_user_plan(user.id, session).await?;
 897
 898                let contacts = self.app_state.db.get_contacts(user.id).await?;
 899
 900                {
 901                    let mut pool = self.connection_pool.lock();
 902                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 903                    self.peer.send(
 904                        connection_id,
 905                        build_initial_contacts_update(contacts, &pool),
 906                    )?;
 907                }
 908
 909                if should_auto_subscribe_to_channels(zed_version) {
 910                    subscribe_user_to_channels(user.id, session).await?;
 911                }
 912
 913                if let Some(incoming_call) =
 914                    self.app_state.db.incoming_call_for_user(user.id).await?
 915                {
 916                    self.peer.send(connection_id, incoming_call)?;
 917                }
 918
 919                update_user_contacts(user.id, session).await?;
 920            }
 921        }
 922
 923        Ok(())
 924    }
 925
 926    pub async fn invite_code_redeemed(
 927        self: &Arc<Self>,
 928        inviter_id: UserId,
 929        invitee_id: UserId,
 930    ) -> Result<()> {
 931        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 932            if let Some(code) = &user.invite_code {
 933                let pool = self.connection_pool.lock();
 934                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 935                for connection_id in pool.user_connection_ids(inviter_id) {
 936                    self.peer.send(
 937                        connection_id,
 938                        proto::UpdateContacts {
 939                            contacts: vec![invitee_contact.clone()],
 940                            ..Default::default()
 941                        },
 942                    )?;
 943                    self.peer.send(
 944                        connection_id,
 945                        proto::UpdateInviteInfo {
 946                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 947                            count: user.invite_count as u32,
 948                        },
 949                    )?;
 950                }
 951            }
 952        }
 953        Ok(())
 954    }
 955
 956    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 957        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 958            if let Some(invite_code) = &user.invite_code {
 959                let pool = self.connection_pool.lock();
 960                for connection_id in pool.user_connection_ids(user_id) {
 961                    self.peer.send(
 962                        connection_id,
 963                        proto::UpdateInviteInfo {
 964                            url: format!(
 965                                "{}{}",
 966                                self.app_state.config.invite_link_prefix, invite_code
 967                            ),
 968                            count: user.invite_count as u32,
 969                        },
 970                    )?;
 971                }
 972            }
 973        }
 974        Ok(())
 975    }
 976
 977    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 978        let pool = self.connection_pool.lock();
 979        for connection_id in pool.user_connection_ids(user_id) {
 980            self.peer
 981                .send(connection_id, proto::RefreshLlmToken {})
 982                .trace_err();
 983        }
 984    }
 985
 986    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 987        ServerSnapshot {
 988            connection_pool: ConnectionPoolGuard {
 989                guard: self.connection_pool.lock(),
 990                _not_send: PhantomData,
 991            },
 992            peer: &self.peer,
 993        }
 994    }
 995}
 996
 997impl Deref for ConnectionPoolGuard<'_> {
 998    type Target = ConnectionPool;
 999
1000    fn deref(&self) -> &Self::Target {
1001        &self.guard
1002    }
1003}
1004
1005impl DerefMut for ConnectionPoolGuard<'_> {
1006    fn deref_mut(&mut self) -> &mut Self::Target {
1007        &mut self.guard
1008    }
1009}
1010
/// Releasing the guard is the checkpoint for pool consistency in test builds.
impl Drop for ConnectionPoolGuard<'_> {
    // In test builds, run `check_invariants` (presumably a `ConnectionPool` method reached
    // through `Deref` — confirm) each time a guard is dropped, so any mutation made through
    // the guard is validated. In non-test builds this drop does nothing extra.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1017
1018fn broadcast<F>(
1019    sender_id: Option<ConnectionId>,
1020    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1021    mut f: F,
1022) where
1023    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1024{
1025    for receiver_id in receiver_ids {
1026        if Some(receiver_id) != sender_id {
1027            if let Err(error) = f(receiver_id) {
1028                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1029            }
1030        }
1031    }
1032}
1033
1034pub struct ProtocolVersion(u32);
1035
1036impl Header for ProtocolVersion {
1037    fn name() -> &'static HeaderName {
1038        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1039        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1040    }
1041
1042    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1043    where
1044        Self: Sized,
1045        I: Iterator<Item = &'i axum::http::HeaderValue>,
1046    {
1047        let version = values
1048            .next()
1049            .ok_or_else(axum::headers::Error::invalid)?
1050            .to_str()
1051            .map_err(|_| axum::headers::Error::invalid())?
1052            .parse()
1053            .map_err(|_| axum::headers::Error::invalid())?;
1054        Ok(Self(version))
1055    }
1056
1057    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1058        values.extend([self.0.to_string().parse().unwrap()]);
1059    }
1060}
1061
1062pub struct AppVersionHeader(SemanticVersion);
1063impl Header for AppVersionHeader {
1064    fn name() -> &'static HeaderName {
1065        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1066        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1067    }
1068
1069    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1070    where
1071        Self: Sized,
1072        I: Iterator<Item = &'i axum::http::HeaderValue>,
1073    {
1074        let version = values
1075            .next()
1076            .ok_or_else(axum::headers::Error::invalid)?
1077            .to_str()
1078            .map_err(|_| axum::headers::Error::invalid())?
1079            .parse()
1080            .map_err(|_| axum::headers::Error::invalid())?;
1081        Ok(Self(version))
1082    }
1083
1084    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1085        values.extend([self.0.to_string().parse().unwrap()]);
1086    }
1087}
1088
1089pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1090    Router::new()
1091        .route("/rpc", get(handle_websocket_request))
1092        .layer(
1093            ServiceBuilder::new()
1094                .layer(Extension(server.app_state.clone()))
1095                .layer(middleware::from_fn(auth::validate_header)),
1096        )
1097        .route("/metrics", get(handle_metrics))
1098        .layer(Extension(server))
1099}
1100
1101pub async fn handle_websocket_request(
1102    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1103    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1104    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1105    Extension(server): Extension<Arc<Server>>,
1106    Extension(principal): Extension<Principal>,
1107    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1108    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1109    ws: WebSocketUpgrade,
1110) -> axum::response::Response {
1111    if protocol_version != rpc::PROTOCOL_VERSION {
1112        return (
1113            StatusCode::UPGRADE_REQUIRED,
1114            "client must be upgraded".to_string(),
1115        )
1116            .into_response();
1117    }
1118
1119    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1120        return (
1121            StatusCode::UPGRADE_REQUIRED,
1122            "no version header found".to_string(),
1123        )
1124            .into_response();
1125    };
1126
1127    if !version.can_collaborate() {
1128        return (
1129            StatusCode::UPGRADE_REQUIRED,
1130            "client must be upgraded".to_string(),
1131        )
1132            .into_response();
1133    }
1134
1135    let socket_address = socket_address.to_string();
1136    ws.on_upgrade(move |socket| {
1137        let socket = socket
1138            .map_ok(to_tungstenite_message)
1139            .err_into()
1140            .with(|message| async move { to_axum_message(message) });
1141        let connection = Connection::new(Box::pin(socket));
1142        async move {
1143            server
1144                .handle_connection(
1145                    connection,
1146                    socket_address,
1147                    principal,
1148                    version,
1149                    country_code_header.map(|header| header.to_string()),
1150                    system_id_header.map(|header| header.to_string()),
1151                    None,
1152                    Executor::Production,
1153                )
1154                .await;
1155        }
1156    })
1157}
1158
1159pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1160    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1161    let connections_metric = CONNECTIONS_METRIC
1162        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1163
1164    let connections = server
1165        .connection_pool
1166        .lock()
1167        .connections()
1168        .filter(|connection| !connection.admin)
1169        .count();
1170    connections_metric.set(connections as _);
1171
1172    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1173    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1174        register_int_gauge!(
1175            "shared_projects",
1176            "number of open projects with one or more guests"
1177        )
1178        .unwrap()
1179    });
1180
1181    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1182    shared_projects_metric.set(shared_projects as _);
1183
1184    let encoder = prometheus::TextEncoder::new();
1185    let metric_families = prometheus::gather();
1186    let encoded_metrics = encoder
1187        .encode_to_string(&metric_families)
1188        .map_err(|err| anyhow!("{}", err))?;
1189    Ok(encoded_metrics)
1190}
1191
1192#[instrument(err, skip(executor))]
1193async fn connection_lost(
1194    session: Session,
1195    mut teardown: watch::Receiver<bool>,
1196    executor: Executor,
1197) -> Result<()> {
1198    session.peer.disconnect(session.connection_id);
1199    session
1200        .connection_pool()
1201        .await
1202        .remove_connection(session.connection_id)?;
1203
1204    session
1205        .db()
1206        .await
1207        .connection_lost(session.connection_id)
1208        .await
1209        .trace_err();
1210
1211    futures::select_biased! {
1212        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1213
1214            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1215            leave_room_for_session(&session, session.connection_id).await.trace_err();
1216            leave_channel_buffers_for_session(&session)
1217                .await
1218                .trace_err();
1219
1220            if !session
1221                .connection_pool()
1222                .await
1223                .is_user_online(session.user_id())
1224            {
1225                let db = session.db().await;
1226                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1227                    room_updated(&room, &session.peer);
1228                }
1229            }
1230
1231            update_user_contacts(session.user_id(), &session).await?;
1232        },
1233        _ = teardown.changed().fuse() => {}
1234    }
1235
1236    Ok(())
1237}
1238
1239/// Acknowledges a ping from a client, used to keep the connection alive.
1240async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1241    response.send(proto::Ack {})?;
1242    Ok(())
1243}
1244
1245/// Creates a new room for calling (outside of channels)
1246async fn create_room(
1247    _request: proto::CreateRoom,
1248    response: Response<proto::CreateRoom>,
1249    session: Session,
1250) -> Result<()> {
1251    let livekit_room = nanoid::nanoid!(30);
1252
1253    let live_kit_connection_info = util::maybe!(async {
1254        let live_kit = session.app_state.livekit_client.as_ref();
1255        let live_kit = live_kit?;
1256        let user_id = session.user_id().to_string();
1257
1258        let token = live_kit
1259            .room_token(&livekit_room, &user_id.to_string())
1260            .trace_err()?;
1261
1262        Some(proto::LiveKitConnectionInfo {
1263            server_url: live_kit.url().into(),
1264            token,
1265            can_publish: true,
1266        })
1267    })
1268    .await;
1269
1270    let room = session
1271        .db()
1272        .await
1273        .create_room(session.user_id(), session.connection_id, &livekit_room)
1274        .await?;
1275
1276    response.send(proto::CreateRoomResponse {
1277        room: Some(room.clone()),
1278        live_kit_connection_info,
1279    })?;
1280
1281    update_user_contacts(session.user_id(), &session).await?;
1282    Ok(())
1283}
1284
1285/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1286async fn join_room(
1287    request: proto::JoinRoom,
1288    response: Response<proto::JoinRoom>,
1289    session: Session,
1290) -> Result<()> {
1291    let room_id = RoomId::from_proto(request.id);
1292
1293    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1294
1295    if let Some(channel_id) = channel_id {
1296        return join_channel_internal(channel_id, Box::new(response), session).await;
1297    }
1298
1299    let joined_room = {
1300        let room = session
1301            .db()
1302            .await
1303            .join_room(room_id, session.user_id(), session.connection_id)
1304            .await?;
1305        room_updated(&room.room, &session.peer);
1306        room.into_inner()
1307    };
1308
1309    for connection_id in session
1310        .connection_pool()
1311        .await
1312        .user_connection_ids(session.user_id())
1313    {
1314        session
1315            .peer
1316            .send(
1317                connection_id,
1318                proto::CallCanceled {
1319                    room_id: room_id.to_proto(),
1320                },
1321            )
1322            .trace_err();
1323    }
1324
1325    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1326    {
1327        live_kit
1328            .room_token(
1329                &joined_room.room.livekit_room,
1330                &session.user_id().to_string(),
1331            )
1332            .trace_err()
1333            .map(|token| proto::LiveKitConnectionInfo {
1334                server_url: live_kit.url().into(),
1335                token,
1336                can_publish: true,
1337            })
1338    } else {
1339        None
1340    };
1341
1342    response.send(proto::JoinRoomResponse {
1343        room: Some(joined_room.room),
1344        channel_id: None,
1345        live_kit_connection_info,
1346    })?;
1347
1348    update_user_contacts(session.user_id(), &session).await?;
1349    Ok(())
1350}
1351
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins the user under the session's current connection
/// (`Database::rejoin_room`). The caller then receives:
/// - the room;
/// - the reshared projects, each with its collaborators;
/// - the rejoined projects.
///
/// After the response:
/// - `room_updated` is called;
/// - each collaborator of a reshared project is told the peer id changed from the
///   project's `old_connection_id` to this connection;
/// - every collaborator other than this connection is forwarded the project's current
///   worktrees;
/// - `notify_rejoined_projects` handles the rejoined projects.
///
/// If the room belongs to a channel, `channel_updated` is called. The user's contacts
/// are refreshed last.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the inner scope so they outlive `rejoined_room`, which is
    // consumed by `into_inner()` at the end of that scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects (presumably ones this user hosts — confirm): collaborators
        // learn the new peer id and receive the current worktrees.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // `Some(session.connection_id)` makes `broadcast` skip this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1443
1444fn notify_rejoined_projects(
1445    rejoined_projects: &mut Vec<RejoinedProject>,
1446    session: &Session,
1447) -> Result<()> {
1448    for project in rejoined_projects.iter() {
1449        for collaborator in &project.collaborators {
1450            session
1451                .peer
1452                .send(
1453                    collaborator.connection_id,
1454                    proto::UpdateProjectCollaborator {
1455                        project_id: project.id.to_proto(),
1456                        old_peer_id: Some(project.old_connection_id.into()),
1457                        new_peer_id: Some(session.connection_id.into()),
1458                    },
1459                )
1460                .trace_err();
1461        }
1462    }
1463
1464    for project in rejoined_projects {
1465        for worktree in mem::take(&mut project.worktrees) {
1466            // Stream this worktree's entries.
1467            let message = proto::UpdateWorktree {
1468                project_id: project.id.to_proto(),
1469                worktree_id: worktree.id,
1470                abs_path: worktree.abs_path.clone(),
1471                root_name: worktree.root_name,
1472                updated_entries: worktree.updated_entries,
1473                removed_entries: worktree.removed_entries,
1474                scan_id: worktree.scan_id,
1475                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1476                updated_repositories: worktree.updated_repositories,
1477                removed_repositories: worktree.removed_repositories,
1478            };
1479            for update in proto::split_worktree_update(message) {
1480                session.peer.send(session.connection_id, update)?;
1481            }
1482
1483            // Stream this worktree's diagnostics.
1484            for summary in worktree.diagnostic_summaries {
1485                session.peer.send(
1486                    session.connection_id,
1487                    proto::UpdateDiagnosticSummary {
1488                        project_id: project.id.to_proto(),
1489                        worktree_id: worktree.id,
1490                        summary: Some(summary),
1491                    },
1492                )?;
1493            }
1494
1495            for settings_file in worktree.settings_files {
1496                session.peer.send(
1497                    session.connection_id,
1498                    proto::UpdateWorktreeSettings {
1499                        project_id: project.id.to_proto(),
1500                        worktree_id: worktree.id,
1501                        path: settings_file.path,
1502                        content: Some(settings_file.content),
1503                        kind: Some(settings_file.kind.to_proto().into()),
1504                    },
1505                )?;
1506            }
1507        }
1508
1509        for repository in mem::take(&mut project.updated_repositories) {
1510            for update in split_repository_update(repository) {
1511                session.peer.send(session.connection_id, update)?;
1512            }
1513        }
1514
1515        for id in mem::take(&mut project.removed_repositories) {
1516            session.peer.send(
1517                session.connection_id,
1518                proto::RemoveRepository {
1519                    project_id: project.id.to_proto(),
1520                    id,
1521                },
1522            )?;
1523        }
1524    }
1525
1526    Ok(())
1527}
1528
1529/// leave room disconnects from the room.
1530async fn leave_room(
1531    _: proto::LeaveRoom,
1532    response: Response<proto::LeaveRoom>,
1533    session: Session,
1534) -> Result<()> {
1535    leave_room_for_session(&session, session.connection_id).await?;
1536    response.send(proto::Ack {})?;
1537    Ok(())
1538}
1539
1540/// Updates the permissions of someone else in the room.
1541async fn set_room_participant_role(
1542    request: proto::SetRoomParticipantRole,
1543    response: Response<proto::SetRoomParticipantRole>,
1544    session: Session,
1545) -> Result<()> {
1546    let user_id = UserId::from_proto(request.user_id);
1547    let role = ChannelRole::from(request.role());
1548
1549    let (livekit_room, can_publish) = {
1550        let room = session
1551            .db()
1552            .await
1553            .set_room_participant_role(
1554                session.user_id(),
1555                RoomId::from_proto(request.room_id),
1556                user_id,
1557                role,
1558            )
1559            .await?;
1560
1561        let livekit_room = room.livekit_room.clone();
1562        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1563        room_updated(&room, &session.peer);
1564        (livekit_room, can_publish)
1565    };
1566
1567    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1568        live_kit
1569            .update_participant(
1570                livekit_room.clone(),
1571                request.user_id.to_string(),
1572                livekit_api::proto::ParticipantPermission {
1573                    can_subscribe: true,
1574                    can_publish,
1575                    can_publish_data: can_publish,
1576                    hidden: false,
1577                    recorder: false,
1578                },
1579            )
1580            .await
1581            .trace_err();
1582    }
1583
1584    response.send(proto::Ack {})?;
1585    Ok(())
1586}
1587
1588/// Call someone else into the current room
1589async fn call(
1590    request: proto::Call,
1591    response: Response<proto::Call>,
1592    session: Session,
1593) -> Result<()> {
1594    let room_id = RoomId::from_proto(request.room_id);
1595    let calling_user_id = session.user_id();
1596    let calling_connection_id = session.connection_id;
1597    let called_user_id = UserId::from_proto(request.called_user_id);
1598    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1599    if !session
1600        .db()
1601        .await
1602        .has_contact(calling_user_id, called_user_id)
1603        .await?
1604    {
1605        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1606    }
1607
1608    let incoming_call = {
1609        let (room, incoming_call) = &mut *session
1610            .db()
1611            .await
1612            .call(
1613                room_id,
1614                calling_user_id,
1615                calling_connection_id,
1616                called_user_id,
1617                initial_project_id,
1618            )
1619            .await?;
1620        room_updated(room, &session.peer);
1621        mem::take(incoming_call)
1622    };
1623    update_user_contacts(called_user_id, &session).await?;
1624
1625    let mut calls = session
1626        .connection_pool()
1627        .await
1628        .user_connection_ids(called_user_id)
1629        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1630        .collect::<FuturesUnordered<_>>();
1631
1632    while let Some(call_response) = calls.next().await {
1633        match call_response.as_ref() {
1634            Ok(_) => {
1635                response.send(proto::Ack {})?;
1636                return Ok(());
1637            }
1638            Err(_) => {
1639                call_response.trace_err();
1640            }
1641        }
1642    }
1643
1644    {
1645        let room = session
1646            .db()
1647            .await
1648            .call_failed(room_id, called_user_id)
1649            .await?;
1650        room_updated(&room, &session.peer);
1651    }
1652    update_user_contacts(called_user_id, &session).await?;
1653
1654    Err(anyhow!("failed to ring user"))?
1655}
1656
1657/// Cancel an outgoing call.
1658async fn cancel_call(
1659    request: proto::CancelCall,
1660    response: Response<proto::CancelCall>,
1661    session: Session,
1662) -> Result<()> {
1663    let called_user_id = UserId::from_proto(request.called_user_id);
1664    let room_id = RoomId::from_proto(request.room_id);
1665    {
1666        let room = session
1667            .db()
1668            .await
1669            .cancel_call(room_id, session.connection_id, called_user_id)
1670            .await?;
1671        room_updated(&room, &session.peer);
1672    }
1673
1674    for connection_id in session
1675        .connection_pool()
1676        .await
1677        .user_connection_ids(called_user_id)
1678    {
1679        session
1680            .peer
1681            .send(
1682                connection_id,
1683                proto::CallCanceled {
1684                    room_id: room_id.to_proto(),
1685                },
1686            )
1687            .trace_err();
1688    }
1689    response.send(proto::Ack {})?;
1690
1691    update_user_contacts(called_user_id, &session).await?;
1692    Ok(())
1693}
1694
1695/// Decline an incoming call.
1696async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1697    let room_id = RoomId::from_proto(message.room_id);
1698    {
1699        let room = session
1700            .db()
1701            .await
1702            .decline_call(Some(room_id), session.user_id())
1703            .await?
1704            .ok_or_else(|| anyhow!("failed to decline call"))?;
1705        room_updated(&room, &session.peer);
1706    }
1707
1708    for connection_id in session
1709        .connection_pool()
1710        .await
1711        .user_connection_ids(session.user_id())
1712    {
1713        session
1714            .peer
1715            .send(
1716                connection_id,
1717                proto::CallCanceled {
1718                    room_id: room_id.to_proto(),
1719                },
1720            )
1721            .trace_err();
1722    }
1723    update_user_contacts(session.user_id(), &session).await?;
1724    Ok(())
1725}
1726
1727/// Updates other participants in the room with your current location.
1728async fn update_participant_location(
1729    request: proto::UpdateParticipantLocation,
1730    response: Response<proto::UpdateParticipantLocation>,
1731    session: Session,
1732) -> Result<()> {
1733    let room_id = RoomId::from_proto(request.room_id);
1734    let location = request
1735        .location
1736        .ok_or_else(|| anyhow!("invalid location"))?;
1737
1738    let db = session.db().await;
1739    let room = db
1740        .update_room_participant_location(room_id, session.connection_id, location)
1741        .await?;
1742
1743    room_updated(&room, &session.peer);
1744    response.send(proto::Ack {})?;
1745    Ok(())
1746}
1747
1748/// Share a project into the room.
1749async fn share_project(
1750    request: proto::ShareProject,
1751    response: Response<proto::ShareProject>,
1752    session: Session,
1753) -> Result<()> {
1754    let (project_id, room) = &*session
1755        .db()
1756        .await
1757        .share_project(
1758            RoomId::from_proto(request.room_id),
1759            session.connection_id,
1760            &request.worktrees,
1761            request.is_ssh_project,
1762        )
1763        .await?;
1764    response.send(proto::ShareProjectResponse {
1765        project_id: project_id.to_proto(),
1766    })?;
1767    room_updated(room, &session.peer);
1768
1769    Ok(())
1770}
1771
1772/// Unshare a project from the room.
1773async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1774    let project_id = ProjectId::from_proto(message.project_id);
1775    unshare_project_internal(project_id, session.connection_id, &session).await
1776}
1777
1778async fn unshare_project_internal(
1779    project_id: ProjectId,
1780    connection_id: ConnectionId,
1781    session: &Session,
1782) -> Result<()> {
1783    let delete = {
1784        let room_guard = session
1785            .db()
1786            .await
1787            .unshare_project(project_id, connection_id)
1788            .await?;
1789
1790        let (delete, room, guest_connection_ids) = &*room_guard;
1791
1792        let message = proto::UnshareProject {
1793            project_id: project_id.to_proto(),
1794        };
1795
1796        broadcast(
1797            Some(connection_id),
1798            guest_connection_ids.iter().copied(),
1799            |conn_id| session.peer.send(conn_id, message.clone()),
1800        );
1801        if let Some(room) = room {
1802            room_updated(room, &session.peer);
1803        }
1804
1805        *delete
1806    };
1807
1808    if delete {
1809        let db = session.db().await;
1810        db.delete_project(project_id).await?;
1811    }
1812
1813    Ok(())
1814}
1815
1816/// Join someone elses shared project.
1817async fn join_project(
1818    request: proto::JoinProject,
1819    response: Response<proto::JoinProject>,
1820    session: Session,
1821) -> Result<()> {
1822    let project_id = ProjectId::from_proto(request.project_id);
1823
1824    tracing::info!(%project_id, "join project");
1825
1826    let db = session.db().await;
1827    let (project, replica_id) = &mut *db
1828        .join_project(project_id, session.connection_id, session.user_id())
1829        .await?;
1830    drop(db);
1831    tracing::info!(%project_id, "join remote project");
1832    join_project_internal(response, session, project, replica_id)
1833}
1834
1835trait JoinProjectInternalResponse {
1836    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1837}
1838impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1839    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1840        Response::<proto::JoinProject>::send(self, result)
1841    }
1842}
1843
1844fn join_project_internal(
1845    response: impl JoinProjectInternalResponse,
1846    session: Session,
1847    project: &mut Project,
1848    replica_id: &ReplicaId,
1849) -> Result<()> {
1850    let collaborators = project
1851        .collaborators
1852        .iter()
1853        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1854        .map(|collaborator| collaborator.to_proto())
1855        .collect::<Vec<_>>();
1856    let project_id = project.id;
1857    let guest_user_id = session.user_id();
1858
1859    let worktrees = project
1860        .worktrees
1861        .iter()
1862        .map(|(id, worktree)| proto::WorktreeMetadata {
1863            id: *id,
1864            root_name: worktree.root_name.clone(),
1865            visible: worktree.visible,
1866            abs_path: worktree.abs_path.clone(),
1867        })
1868        .collect::<Vec<_>>();
1869
1870    let add_project_collaborator = proto::AddProjectCollaborator {
1871        project_id: project_id.to_proto(),
1872        collaborator: Some(proto::Collaborator {
1873            peer_id: Some(session.connection_id.into()),
1874            replica_id: replica_id.0 as u32,
1875            user_id: guest_user_id.to_proto(),
1876            is_host: false,
1877        }),
1878    };
1879
1880    for collaborator in &collaborators {
1881        session
1882            .peer
1883            .send(
1884                collaborator.peer_id.unwrap().into(),
1885                add_project_collaborator.clone(),
1886            )
1887            .trace_err();
1888    }
1889
1890    // First, we send the metadata associated with each worktree.
1891    response.send(proto::JoinProjectResponse {
1892        project_id: project.id.0 as u64,
1893        worktrees: worktrees.clone(),
1894        replica_id: replica_id.0 as u32,
1895        collaborators: collaborators.clone(),
1896        language_servers: project.language_servers.clone(),
1897        role: project.role.into(),
1898    })?;
1899
1900    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1901        // Stream this worktree's entries.
1902        let message = proto::UpdateWorktree {
1903            project_id: project_id.to_proto(),
1904            worktree_id,
1905            abs_path: worktree.abs_path.clone(),
1906            root_name: worktree.root_name,
1907            updated_entries: worktree.entries,
1908            removed_entries: Default::default(),
1909            scan_id: worktree.scan_id,
1910            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1911            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1912            removed_repositories: Default::default(),
1913        };
1914        for update in proto::split_worktree_update(message) {
1915            session.peer.send(session.connection_id, update.clone())?;
1916        }
1917
1918        // Stream this worktree's diagnostics.
1919        for summary in worktree.diagnostic_summaries {
1920            session.peer.send(
1921                session.connection_id,
1922                proto::UpdateDiagnosticSummary {
1923                    project_id: project_id.to_proto(),
1924                    worktree_id: worktree.id,
1925                    summary: Some(summary),
1926                },
1927            )?;
1928        }
1929
1930        for settings_file in worktree.settings_files {
1931            session.peer.send(
1932                session.connection_id,
1933                proto::UpdateWorktreeSettings {
1934                    project_id: project_id.to_proto(),
1935                    worktree_id: worktree.id,
1936                    path: settings_file.path,
1937                    content: Some(settings_file.content),
1938                    kind: Some(settings_file.kind.to_proto() as i32),
1939                },
1940            )?;
1941        }
1942    }
1943
1944    for repository in mem::take(&mut project.repositories) {
1945        for update in split_repository_update(repository) {
1946            session.peer.send(session.connection_id, update)?;
1947        }
1948    }
1949
1950    for language_server in &project.language_servers {
1951        session.peer.send(
1952            session.connection_id,
1953            proto::UpdateLanguageServer {
1954                project_id: project_id.to_proto(),
1955                language_server_id: language_server.id,
1956                variant: Some(
1957                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1958                        proto::LspDiskBasedDiagnosticsUpdated {},
1959                    ),
1960                ),
1961            },
1962        )?;
1963    }
1964
1965    Ok(())
1966}
1967
1968/// Leave someone elses shared project.
1969async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1970    let sender_id = session.connection_id;
1971    let project_id = ProjectId::from_proto(request.project_id);
1972    let db = session.db().await;
1973
1974    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1975    tracing::info!(
1976        %project_id,
1977        "leave project"
1978    );
1979
1980    project_left(project, &session);
1981    if let Some(room) = room {
1982        room_updated(room, &session.peer);
1983    }
1984
1985    Ok(())
1986}
1987
1988/// Updates other participants with changes to the project
1989async fn update_project(
1990    request: proto::UpdateProject,
1991    response: Response<proto::UpdateProject>,
1992    session: Session,
1993) -> Result<()> {
1994    let project_id = ProjectId::from_proto(request.project_id);
1995    let (room, guest_connection_ids) = &*session
1996        .db()
1997        .await
1998        .update_project(project_id, session.connection_id, &request.worktrees)
1999        .await?;
2000    broadcast(
2001        Some(session.connection_id),
2002        guest_connection_ids.iter().copied(),
2003        |connection_id| {
2004            session
2005                .peer
2006                .forward_send(session.connection_id, connection_id, request.clone())
2007        },
2008    );
2009    if let Some(room) = room {
2010        room_updated(room, &session.peer);
2011    }
2012    response.send(proto::Ack {})?;
2013
2014    Ok(())
2015}
2016
2017/// Updates other participants with changes to the worktree
2018async fn update_worktree(
2019    request: proto::UpdateWorktree,
2020    response: Response<proto::UpdateWorktree>,
2021    session: Session,
2022) -> Result<()> {
2023    let guest_connection_ids = session
2024        .db()
2025        .await
2026        .update_worktree(&request, session.connection_id)
2027        .await?;
2028
2029    broadcast(
2030        Some(session.connection_id),
2031        guest_connection_ids.iter().copied(),
2032        |connection_id| {
2033            session
2034                .peer
2035                .forward_send(session.connection_id, connection_id, request.clone())
2036        },
2037    );
2038    response.send(proto::Ack {})?;
2039    Ok(())
2040}
2041
2042async fn update_repository(
2043    request: proto::UpdateRepository,
2044    response: Response<proto::UpdateRepository>,
2045    session: Session,
2046) -> Result<()> {
2047    let guest_connection_ids = session
2048        .db()
2049        .await
2050        .update_repository(&request, session.connection_id)
2051        .await?;
2052
2053    broadcast(
2054        Some(session.connection_id),
2055        guest_connection_ids.iter().copied(),
2056        |connection_id| {
2057            session
2058                .peer
2059                .forward_send(session.connection_id, connection_id, request.clone())
2060        },
2061    );
2062    response.send(proto::Ack {})?;
2063    Ok(())
2064}
2065
2066async fn remove_repository(
2067    request: proto::RemoveRepository,
2068    response: Response<proto::RemoveRepository>,
2069    session: Session,
2070) -> Result<()> {
2071    let guest_connection_ids = session
2072        .db()
2073        .await
2074        .remove_repository(&request, session.connection_id)
2075        .await?;
2076
2077    broadcast(
2078        Some(session.connection_id),
2079        guest_connection_ids.iter().copied(),
2080        |connection_id| {
2081            session
2082                .peer
2083                .forward_send(session.connection_id, connection_id, request.clone())
2084        },
2085    );
2086    response.send(proto::Ack {})?;
2087    Ok(())
2088}
2089
2090/// Updates other participants with changes to the diagnostics
2091async fn update_diagnostic_summary(
2092    message: proto::UpdateDiagnosticSummary,
2093    session: Session,
2094) -> Result<()> {
2095    let guest_connection_ids = session
2096        .db()
2097        .await
2098        .update_diagnostic_summary(&message, session.connection_id)
2099        .await?;
2100
2101    broadcast(
2102        Some(session.connection_id),
2103        guest_connection_ids.iter().copied(),
2104        |connection_id| {
2105            session
2106                .peer
2107                .forward_send(session.connection_id, connection_id, message.clone())
2108        },
2109    );
2110
2111    Ok(())
2112}
2113
2114/// Updates other participants with changes to the worktree settings
2115async fn update_worktree_settings(
2116    message: proto::UpdateWorktreeSettings,
2117    session: Session,
2118) -> Result<()> {
2119    let guest_connection_ids = session
2120        .db()
2121        .await
2122        .update_worktree_settings(&message, session.connection_id)
2123        .await?;
2124
2125    broadcast(
2126        Some(session.connection_id),
2127        guest_connection_ids.iter().copied(),
2128        |connection_id| {
2129            session
2130                .peer
2131                .forward_send(session.connection_id, connection_id, message.clone())
2132        },
2133    );
2134
2135    Ok(())
2136}
2137
2138/// Notify other participants that a language server has started.
2139async fn start_language_server(
2140    request: proto::StartLanguageServer,
2141    session: Session,
2142) -> Result<()> {
2143    let guest_connection_ids = session
2144        .db()
2145        .await
2146        .start_language_server(&request, session.connection_id)
2147        .await?;
2148
2149    broadcast(
2150        Some(session.connection_id),
2151        guest_connection_ids.iter().copied(),
2152        |connection_id| {
2153            session
2154                .peer
2155                .forward_send(session.connection_id, connection_id, request.clone())
2156        },
2157    );
2158    Ok(())
2159}
2160
2161/// Notify other participants that a language server has changed.
2162async fn update_language_server(
2163    request: proto::UpdateLanguageServer,
2164    session: Session,
2165) -> Result<()> {
2166    let project_id = ProjectId::from_proto(request.project_id);
2167    let project_connection_ids = session
2168        .db()
2169        .await
2170        .project_connection_ids(project_id, session.connection_id, true)
2171        .await?;
2172    broadcast(
2173        Some(session.connection_id),
2174        project_connection_ids.iter().copied(),
2175        |connection_id| {
2176            session
2177                .peer
2178                .forward_send(session.connection_id, connection_id, request.clone())
2179        },
2180    );
2181    Ok(())
2182}
2183
2184/// forward a project request to the host. These requests should be read only
2185/// as guests are allowed to send them.
2186async fn forward_read_only_project_request<T>(
2187    request: T,
2188    response: Response<T>,
2189    session: Session,
2190) -> Result<()>
2191where
2192    T: EntityMessage + RequestMessage,
2193{
2194    let project_id = ProjectId::from_proto(request.remote_entity_id());
2195    let host_connection_id = session
2196        .db()
2197        .await
2198        .host_for_read_only_project_request(project_id, session.connection_id)
2199        .await?;
2200    let payload = session
2201        .peer
2202        .forward_request(session.connection_id, host_connection_id, request)
2203        .await?;
2204    response.send(payload)?;
2205    Ok(())
2206}
2207
2208async fn forward_find_search_candidates_request(
2209    request: proto::FindSearchCandidates,
2210    response: Response<proto::FindSearchCandidates>,
2211    session: Session,
2212) -> Result<()> {
2213    let project_id = ProjectId::from_proto(request.remote_entity_id());
2214    let host_connection_id = session
2215        .db()
2216        .await
2217        .host_for_read_only_project_request(project_id, session.connection_id)
2218        .await?;
2219    let payload = session
2220        .peer
2221        .forward_request(session.connection_id, host_connection_id, request)
2222        .await?;
2223    response.send(payload)?;
2224    Ok(())
2225}
2226
2227/// forward a project request to the host. These requests are disallowed
2228/// for guests.
2229async fn forward_mutating_project_request<T>(
2230    request: T,
2231    response: Response<T>,
2232    session: Session,
2233) -> Result<()>
2234where
2235    T: EntityMessage + RequestMessage,
2236{
2237    let project_id = ProjectId::from_proto(request.remote_entity_id());
2238
2239    let host_connection_id = session
2240        .db()
2241        .await
2242        .host_for_mutating_project_request(project_id, session.connection_id)
2243        .await?;
2244    let payload = session
2245        .peer
2246        .forward_request(session.connection_id, host_connection_id, request)
2247        .await?;
2248    response.send(payload)?;
2249    Ok(())
2250}
2251
2252/// Notify other participants that a new buffer has been created
2253async fn create_buffer_for_peer(
2254    request: proto::CreateBufferForPeer,
2255    session: Session,
2256) -> Result<()> {
2257    session
2258        .db()
2259        .await
2260        .check_user_is_project_host(
2261            ProjectId::from_proto(request.project_id),
2262            session.connection_id,
2263        )
2264        .await?;
2265    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2266    session
2267        .peer
2268        .forward_send(session.connection_id, peer_id.into(), request)?;
2269    Ok(())
2270}
2271
2272/// Notify other participants that a buffer has been updated. This is
2273/// allowed for guests as long as the update is limited to selections.
2274async fn update_buffer(
2275    request: proto::UpdateBuffer,
2276    response: Response<proto::UpdateBuffer>,
2277    session: Session,
2278) -> Result<()> {
2279    let project_id = ProjectId::from_proto(request.project_id);
2280    let mut capability = Capability::ReadOnly;
2281
2282    for op in request.operations.iter() {
2283        match op.variant {
2284            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2285            Some(_) => capability = Capability::ReadWrite,
2286        }
2287    }
2288
2289    let host = {
2290        let guard = session
2291            .db()
2292            .await
2293            .connections_for_buffer_update(project_id, session.connection_id, capability)
2294            .await?;
2295
2296        let (host, guests) = &*guard;
2297
2298        broadcast(
2299            Some(session.connection_id),
2300            guests.clone(),
2301            |connection_id| {
2302                session
2303                    .peer
2304                    .forward_send(session.connection_id, connection_id, request.clone())
2305            },
2306        );
2307
2308        *host
2309    };
2310
2311    if host != session.connection_id {
2312        session
2313            .peer
2314            .forward_request(session.connection_id, host, request.clone())
2315            .await?;
2316    }
2317
2318    response.send(proto::Ack {})?;
2319    Ok(())
2320}
2321
2322async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2323    let project_id = ProjectId::from_proto(message.project_id);
2324
2325    let operation = message.operation.as_ref().context("invalid operation")?;
2326    let capability = match operation.variant.as_ref() {
2327        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2328            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2329                match buffer_op.variant {
2330                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2331                        Capability::ReadOnly
2332                    }
2333                    _ => Capability::ReadWrite,
2334                }
2335            } else {
2336                Capability::ReadWrite
2337            }
2338        }
2339        Some(_) => Capability::ReadWrite,
2340        None => Capability::ReadOnly,
2341    };
2342
2343    let guard = session
2344        .db()
2345        .await
2346        .connections_for_buffer_update(project_id, session.connection_id, capability)
2347        .await?;
2348
2349    let (host, guests) = &*guard;
2350
2351    broadcast(
2352        Some(session.connection_id),
2353        guests.iter().chain([host]).copied(),
2354        |connection_id| {
2355            session
2356                .peer
2357                .forward_send(session.connection_id, connection_id, message.clone())
2358        },
2359    );
2360
2361    Ok(())
2362}
2363
2364/// Notify other participants that a project has been updated.
2365async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2366    request: T,
2367    session: Session,
2368) -> Result<()> {
2369    let project_id = ProjectId::from_proto(request.remote_entity_id());
2370    let project_connection_ids = session
2371        .db()
2372        .await
2373        .project_connection_ids(project_id, session.connection_id, false)
2374        .await?;
2375
2376    broadcast(
2377        Some(session.connection_id),
2378        project_connection_ids.iter().copied(),
2379        |connection_id| {
2380            session
2381                .peer
2382                .forward_send(session.connection_id, connection_id, request.clone())
2383        },
2384    );
2385    Ok(())
2386}
2387
2388/// Start following another user in a call.
2389async fn follow(
2390    request: proto::Follow,
2391    response: Response<proto::Follow>,
2392    session: Session,
2393) -> Result<()> {
2394    let room_id = RoomId::from_proto(request.room_id);
2395    let project_id = request.project_id.map(ProjectId::from_proto);
2396    let leader_id = request
2397        .leader_id
2398        .ok_or_else(|| anyhow!("invalid leader id"))?
2399        .into();
2400    let follower_id = session.connection_id;
2401
2402    session
2403        .db()
2404        .await
2405        .check_room_participants(room_id, leader_id, session.connection_id)
2406        .await?;
2407
2408    let response_payload = session
2409        .peer
2410        .forward_request(session.connection_id, leader_id, request)
2411        .await?;
2412    response.send(response_payload)?;
2413
2414    if let Some(project_id) = project_id {
2415        let room = session
2416            .db()
2417            .await
2418            .follow(room_id, project_id, leader_id, follower_id)
2419            .await?;
2420        room_updated(&room, &session.peer);
2421    }
2422
2423    Ok(())
2424}
2425
2426/// Stop following another user in a call.
2427async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2428    let room_id = RoomId::from_proto(request.room_id);
2429    let project_id = request.project_id.map(ProjectId::from_proto);
2430    let leader_id = request
2431        .leader_id
2432        .ok_or_else(|| anyhow!("invalid leader id"))?
2433        .into();
2434    let follower_id = session.connection_id;
2435
2436    session
2437        .db()
2438        .await
2439        .check_room_participants(room_id, leader_id, session.connection_id)
2440        .await?;
2441
2442    session
2443        .peer
2444        .forward_send(session.connection_id, leader_id, request)?;
2445
2446    if let Some(project_id) = project_id {
2447        let room = session
2448            .db()
2449            .await
2450            .unfollow(room_id, project_id, leader_id, follower_id)
2451            .await?;
2452        room_updated(&room, &session.peer);
2453    }
2454
2455    Ok(())
2456}
2457
2458/// Notify everyone following you of your current location.
2459async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2460    let room_id = RoomId::from_proto(request.room_id);
2461    let database = session.db.lock().await;
2462
2463    let connection_ids = if let Some(project_id) = request.project_id {
2464        let project_id = ProjectId::from_proto(project_id);
2465        database
2466            .project_connection_ids(project_id, session.connection_id, true)
2467            .await?
2468    } else {
2469        database
2470            .room_connection_ids(room_id, session.connection_id)
2471            .await?
2472    };
2473
2474    // For now, don't send view update messages back to that view's current leader.
2475    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2476        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2477        _ => None,
2478    });
2479
2480    for connection_id in connection_ids.iter().cloned() {
2481        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2482            session
2483                .peer
2484                .forward_send(session.connection_id, connection_id, request.clone())?;
2485        }
2486    }
2487    Ok(())
2488}
2489
2490/// Get public data about users.
2491async fn get_users(
2492    request: proto::GetUsers,
2493    response: Response<proto::GetUsers>,
2494    session: Session,
2495) -> Result<()> {
2496    let user_ids = request
2497        .user_ids
2498        .into_iter()
2499        .map(UserId::from_proto)
2500        .collect();
2501    let users = session
2502        .db()
2503        .await
2504        .get_users_by_ids(user_ids)
2505        .await?
2506        .into_iter()
2507        .map(|user| proto::User {
2508            id: user.id.to_proto(),
2509            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2510            github_login: user.github_login,
2511            email: user.email_address,
2512            name: user.name,
2513        })
2514        .collect();
2515    response.send(proto::UsersResponse { users })?;
2516    Ok(())
2517}
2518
2519/// Search for users (to invite) buy Github login
2520async fn fuzzy_search_users(
2521    request: proto::FuzzySearchUsers,
2522    response: Response<proto::FuzzySearchUsers>,
2523    session: Session,
2524) -> Result<()> {
2525    let query = request.query;
2526    let users = match query.len() {
2527        0 => vec![],
2528        1 | 2 => session
2529            .db()
2530            .await
2531            .get_user_by_github_login(&query)
2532            .await?
2533            .into_iter()
2534            .collect(),
2535        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2536    };
2537    let users = users
2538        .into_iter()
2539        .filter(|user| user.id != session.user_id())
2540        .map(|user| proto::User {
2541            id: user.id.to_proto(),
2542            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2543            github_login: user.github_login,
2544            name: user.name,
2545            email: user.email_address,
2546        })
2547        .collect();
2548    response.send(proto::UsersResponse { users })?;
2549    Ok(())
2550}
2551
2552/// Send a contact request to another user.
2553async fn request_contact(
2554    request: proto::RequestContact,
2555    response: Response<proto::RequestContact>,
2556    session: Session,
2557) -> Result<()> {
2558    let requester_id = session.user_id();
2559    let responder_id = UserId::from_proto(request.responder_id);
2560    if requester_id == responder_id {
2561        return Err(anyhow!("cannot add yourself as a contact"))?;
2562    }
2563
2564    let notifications = session
2565        .db()
2566        .await
2567        .send_contact_request(requester_id, responder_id)
2568        .await?;
2569
2570    // Update outgoing contact requests of requester
2571    let mut update = proto::UpdateContacts::default();
2572    update.outgoing_requests.push(responder_id.to_proto());
2573    for connection_id in session
2574        .connection_pool()
2575        .await
2576        .user_connection_ids(requester_id)
2577    {
2578        session.peer.send(connection_id, update.clone())?;
2579    }
2580
2581    // Update incoming contact requests of responder
2582    let mut update = proto::UpdateContacts::default();
2583    update
2584        .incoming_requests
2585        .push(proto::IncomingContactRequest {
2586            requester_id: requester_id.to_proto(),
2587        });
2588    let connection_pool = session.connection_pool().await;
2589    for connection_id in connection_pool.user_connection_ids(responder_id) {
2590        session.peer.send(connection_id, update.clone())?;
2591    }
2592
2593    send_notifications(&connection_pool, &session.peer, notifications);
2594
2595    response.send(proto::Ack {})?;
2596    Ok(())
2597}
2598
2599/// Accept or decline a contact request
2600async fn respond_to_contact_request(
2601    request: proto::RespondToContactRequest,
2602    response: Response<proto::RespondToContactRequest>,
2603    session: Session,
2604) -> Result<()> {
2605    let responder_id = session.user_id();
2606    let requester_id = UserId::from_proto(request.requester_id);
2607    let db = session.db().await;
2608    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2609        db.dismiss_contact_notification(responder_id, requester_id)
2610            .await?;
2611    } else {
2612        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2613
2614        let notifications = db
2615            .respond_to_contact_request(responder_id, requester_id, accept)
2616            .await?;
2617        let requester_busy = db.is_user_busy(requester_id).await?;
2618        let responder_busy = db.is_user_busy(responder_id).await?;
2619
2620        let pool = session.connection_pool().await;
2621        // Update responder with new contact
2622        let mut update = proto::UpdateContacts::default();
2623        if accept {
2624            update
2625                .contacts
2626                .push(contact_for_user(requester_id, requester_busy, &pool));
2627        }
2628        update
2629            .remove_incoming_requests
2630            .push(requester_id.to_proto());
2631        for connection_id in pool.user_connection_ids(responder_id) {
2632            session.peer.send(connection_id, update.clone())?;
2633        }
2634
2635        // Update requester with new contact
2636        let mut update = proto::UpdateContacts::default();
2637        if accept {
2638            update
2639                .contacts
2640                .push(contact_for_user(responder_id, responder_busy, &pool));
2641        }
2642        update
2643            .remove_outgoing_requests
2644            .push(responder_id.to_proto());
2645
2646        for connection_id in pool.user_connection_ids(requester_id) {
2647            session.peer.send(connection_id, update.clone())?;
2648        }
2649
2650        send_notifications(&pool, &session.peer, notifications);
2651    }
2652
2653    response.send(proto::Ack {})?;
2654    Ok(())
2655}
2656
2657/// Remove a contact.
2658async fn remove_contact(
2659    request: proto::RemoveContact,
2660    response: Response<proto::RemoveContact>,
2661    session: Session,
2662) -> Result<()> {
2663    let requester_id = session.user_id();
2664    let responder_id = UserId::from_proto(request.user_id);
2665    let db = session.db().await;
2666    let (contact_accepted, deleted_notification_id) =
2667        db.remove_contact(requester_id, responder_id).await?;
2668
2669    let pool = session.connection_pool().await;
2670    // Update outgoing contact requests of requester
2671    let mut update = proto::UpdateContacts::default();
2672    if contact_accepted {
2673        update.remove_contacts.push(responder_id.to_proto());
2674    } else {
2675        update
2676            .remove_outgoing_requests
2677            .push(responder_id.to_proto());
2678    }
2679    for connection_id in pool.user_connection_ids(requester_id) {
2680        session.peer.send(connection_id, update.clone())?;
2681    }
2682
2683    // Update incoming contact requests of responder
2684    let mut update = proto::UpdateContacts::default();
2685    if contact_accepted {
2686        update.remove_contacts.push(requester_id.to_proto());
2687    } else {
2688        update
2689            .remove_incoming_requests
2690            .push(requester_id.to_proto());
2691    }
2692    for connection_id in pool.user_connection_ids(responder_id) {
2693        session.peer.send(connection_id, update.clone())?;
2694        if let Some(notification_id) = deleted_notification_id {
2695            session.peer.send(
2696                connection_id,
2697                proto::DeleteNotification {
2698                    notification_id: notification_id.to_proto(),
2699                },
2700            )?;
2701        }
2702    }
2703
2704    response.send(proto::Ack {})?;
2705    Ok(())
2706}
2707
2708fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2709    version.0.minor() < 139
2710}
2711
2712async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2713    let plan = session.current_plan(&session.db().await).await?;
2714
2715    session
2716        .peer
2717        .send(
2718            session.connection_id,
2719            proto::UpdateUserPlan { plan: plan.into() },
2720        )
2721        .trace_err();
2722
2723    Ok(())
2724}
2725
2726async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2727    subscribe_user_to_channels(session.user_id(), &session).await?;
2728    Ok(())
2729}
2730
2731async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2732    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2733    let mut pool = session.connection_pool().await;
2734    for membership in &channels_for_user.channel_memberships {
2735        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2736    }
2737    session.peer.send(
2738        session.connection_id,
2739        build_update_user_channels(&channels_for_user),
2740    )?;
2741    session.peer.send(
2742        session.connection_id,
2743        build_channels_update(channels_for_user),
2744    )?;
2745    Ok(())
2746}
2747
2748/// Creates a new channel.
2749async fn create_channel(
2750    request: proto::CreateChannel,
2751    response: Response<proto::CreateChannel>,
2752    session: Session,
2753) -> Result<()> {
2754    let db = session.db().await;
2755
2756    let parent_id = request.parent_id.map(ChannelId::from_proto);
2757    let (channel, membership) = db
2758        .create_channel(&request.name, parent_id, session.user_id())
2759        .await?;
2760
2761    let root_id = channel.root_id();
2762    let channel = Channel::from_model(channel);
2763
2764    response.send(proto::CreateChannelResponse {
2765        channel: Some(channel.to_proto()),
2766        parent_id: request.parent_id,
2767    })?;
2768
2769    let mut connection_pool = session.connection_pool().await;
2770    if let Some(membership) = membership {
2771        connection_pool.subscribe_to_channel(
2772            membership.user_id,
2773            membership.channel_id,
2774            membership.role,
2775        );
2776        let update = proto::UpdateUserChannels {
2777            channel_memberships: vec![proto::ChannelMembership {
2778                channel_id: membership.channel_id.to_proto(),
2779                role: membership.role.into(),
2780            }],
2781            ..Default::default()
2782        };
2783        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2784            session.peer.send(connection_id, update.clone())?;
2785        }
2786    }
2787
2788    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2789        if !role.can_see_channel(channel.visibility) {
2790            continue;
2791        }
2792
2793        let update = proto::UpdateChannels {
2794            channels: vec![channel.to_proto()],
2795            ..Default::default()
2796        };
2797        session.peer.send(connection_id, update.clone())?;
2798    }
2799
2800    Ok(())
2801}
2802
2803/// Delete a channel
2804async fn delete_channel(
2805    request: proto::DeleteChannel,
2806    response: Response<proto::DeleteChannel>,
2807    session: Session,
2808) -> Result<()> {
2809    let db = session.db().await;
2810
2811    let channel_id = request.channel_id;
2812    let (root_channel, removed_channels) = db
2813        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2814        .await?;
2815    response.send(proto::Ack {})?;
2816
2817    // Notify members of removed channels
2818    let mut update = proto::UpdateChannels::default();
2819    update
2820        .delete_channels
2821        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2822
2823    let connection_pool = session.connection_pool().await;
2824    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2825        session.peer.send(connection_id, update.clone())?;
2826    }
2827
2828    Ok(())
2829}
2830
2831/// Invite someone to join a channel.
2832async fn invite_channel_member(
2833    request: proto::InviteChannelMember,
2834    response: Response<proto::InviteChannelMember>,
2835    session: Session,
2836) -> Result<()> {
2837    let db = session.db().await;
2838    let channel_id = ChannelId::from_proto(request.channel_id);
2839    let invitee_id = UserId::from_proto(request.user_id);
2840    let InviteMemberResult {
2841        channel,
2842        notifications,
2843    } = db
2844        .invite_channel_member(
2845            channel_id,
2846            invitee_id,
2847            session.user_id(),
2848            request.role().into(),
2849        )
2850        .await?;
2851
2852    let update = proto::UpdateChannels {
2853        channel_invitations: vec![channel.to_proto()],
2854        ..Default::default()
2855    };
2856
2857    let connection_pool = session.connection_pool().await;
2858    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2859        session.peer.send(connection_id, update.clone())?;
2860    }
2861
2862    send_notifications(&connection_pool, &session.peer, notifications);
2863
2864    response.send(proto::Ack {})?;
2865    Ok(())
2866}
2867
2868/// remove someone from a channel
2869async fn remove_channel_member(
2870    request: proto::RemoveChannelMember,
2871    response: Response<proto::RemoveChannelMember>,
2872    session: Session,
2873) -> Result<()> {
2874    let db = session.db().await;
2875    let channel_id = ChannelId::from_proto(request.channel_id);
2876    let member_id = UserId::from_proto(request.user_id);
2877
2878    let RemoveChannelMemberResult {
2879        membership_update,
2880        notification_id,
2881    } = db
2882        .remove_channel_member(channel_id, member_id, session.user_id())
2883        .await?;
2884
2885    let mut connection_pool = session.connection_pool().await;
2886    notify_membership_updated(
2887        &mut connection_pool,
2888        membership_update,
2889        member_id,
2890        &session.peer,
2891    );
2892    for connection_id in connection_pool.user_connection_ids(member_id) {
2893        if let Some(notification_id) = notification_id {
2894            session
2895                .peer
2896                .send(
2897                    connection_id,
2898                    proto::DeleteNotification {
2899                        notification_id: notification_id.to_proto(),
2900                    },
2901                )
2902                .trace_err();
2903        }
2904    }
2905
2906    response.send(proto::Ack {})?;
2907    Ok(())
2908}
2909
2910/// Toggle the channel between public and private.
2911/// Care is taken to maintain the invariant that public channels only descend from public channels,
2912/// (though members-only channels can appear at any point in the hierarchy).
2913async fn set_channel_visibility(
2914    request: proto::SetChannelVisibility,
2915    response: Response<proto::SetChannelVisibility>,
2916    session: Session,
2917) -> Result<()> {
2918    let db = session.db().await;
2919    let channel_id = ChannelId::from_proto(request.channel_id);
2920    let visibility = request.visibility().into();
2921
2922    let channel_model = db
2923        .set_channel_visibility(channel_id, visibility, session.user_id())
2924        .await?;
2925    let root_id = channel_model.root_id();
2926    let channel = Channel::from_model(channel_model);
2927
2928    let mut connection_pool = session.connection_pool().await;
2929    for (user_id, role) in connection_pool
2930        .channel_user_ids(root_id)
2931        .collect::<Vec<_>>()
2932        .into_iter()
2933    {
2934        let update = if role.can_see_channel(channel.visibility) {
2935            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2936            proto::UpdateChannels {
2937                channels: vec![channel.to_proto()],
2938                ..Default::default()
2939            }
2940        } else {
2941            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2942            proto::UpdateChannels {
2943                delete_channels: vec![channel.id.to_proto()],
2944                ..Default::default()
2945            }
2946        };
2947
2948        for connection_id in connection_pool.user_connection_ids(user_id) {
2949            session.peer.send(connection_id, update.clone())?;
2950        }
2951    }
2952
2953    response.send(proto::Ack {})?;
2954    Ok(())
2955}
2956
2957/// Alter the role for a user in the channel.
2958async fn set_channel_member_role(
2959    request: proto::SetChannelMemberRole,
2960    response: Response<proto::SetChannelMemberRole>,
2961    session: Session,
2962) -> Result<()> {
2963    let db = session.db().await;
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    let member_id = UserId::from_proto(request.user_id);
2966    let result = db
2967        .set_channel_member_role(
2968            channel_id,
2969            session.user_id(),
2970            member_id,
2971            request.role().into(),
2972        )
2973        .await?;
2974
2975    match result {
2976        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2977            let mut connection_pool = session.connection_pool().await;
2978            notify_membership_updated(
2979                &mut connection_pool,
2980                membership_update,
2981                member_id,
2982                &session.peer,
2983            )
2984        }
2985        db::SetMemberRoleResult::InviteUpdated(channel) => {
2986            let update = proto::UpdateChannels {
2987                channel_invitations: vec![channel.to_proto()],
2988                ..Default::default()
2989            };
2990
2991            for connection_id in session
2992                .connection_pool()
2993                .await
2994                .user_connection_ids(member_id)
2995            {
2996                session.peer.send(connection_id, update.clone())?;
2997            }
2998        }
2999    }
3000
3001    response.send(proto::Ack {})?;
3002    Ok(())
3003}
3004
3005/// Change the name of a channel
3006async fn rename_channel(
3007    request: proto::RenameChannel,
3008    response: Response<proto::RenameChannel>,
3009    session: Session,
3010) -> Result<()> {
3011    let db = session.db().await;
3012    let channel_id = ChannelId::from_proto(request.channel_id);
3013    let channel_model = db
3014        .rename_channel(channel_id, session.user_id(), &request.name)
3015        .await?;
3016    let root_id = channel_model.root_id();
3017    let channel = Channel::from_model(channel_model);
3018
3019    response.send(proto::RenameChannelResponse {
3020        channel: Some(channel.to_proto()),
3021    })?;
3022
3023    let connection_pool = session.connection_pool().await;
3024    let update = proto::UpdateChannels {
3025        channels: vec![channel.to_proto()],
3026        ..Default::default()
3027    };
3028    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3029        if role.can_see_channel(channel.visibility) {
3030            session.peer.send(connection_id, update.clone())?;
3031        }
3032    }
3033
3034    Ok(())
3035}
3036
3037/// Move a channel to a new parent.
3038async fn move_channel(
3039    request: proto::MoveChannel,
3040    response: Response<proto::MoveChannel>,
3041    session: Session,
3042) -> Result<()> {
3043    let channel_id = ChannelId::from_proto(request.channel_id);
3044    let to = ChannelId::from_proto(request.to);
3045
3046    let (root_id, channels) = session
3047        .db()
3048        .await
3049        .move_channel(channel_id, to, session.user_id())
3050        .await?;
3051
3052    let connection_pool = session.connection_pool().await;
3053    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3054        let channels = channels
3055            .iter()
3056            .filter_map(|channel| {
3057                if role.can_see_channel(channel.visibility) {
3058                    Some(channel.to_proto())
3059                } else {
3060                    None
3061                }
3062            })
3063            .collect::<Vec<_>>();
3064        if channels.is_empty() {
3065            continue;
3066        }
3067
3068        let update = proto::UpdateChannels {
3069            channels,
3070            ..Default::default()
3071        };
3072
3073        session.peer.send(connection_id, update.clone())?;
3074    }
3075
3076    response.send(Ack {})?;
3077    Ok(())
3078}
3079
3080/// Get the list of channel members
3081async fn get_channel_members(
3082    request: proto::GetChannelMembers,
3083    response: Response<proto::GetChannelMembers>,
3084    session: Session,
3085) -> Result<()> {
3086    let db = session.db().await;
3087    let channel_id = ChannelId::from_proto(request.channel_id);
3088    let limit = if request.limit == 0 {
3089        u16::MAX as u64
3090    } else {
3091        request.limit
3092    };
3093    let (members, users) = db
3094        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3095        .await?;
3096    response.send(proto::GetChannelMembersResponse { members, users })?;
3097    Ok(())
3098}
3099
3100/// Accept or decline a channel invitation.
3101async fn respond_to_channel_invite(
3102    request: proto::RespondToChannelInvite,
3103    response: Response<proto::RespondToChannelInvite>,
3104    session: Session,
3105) -> Result<()> {
3106    let db = session.db().await;
3107    let channel_id = ChannelId::from_proto(request.channel_id);
3108    let RespondToChannelInvite {
3109        membership_update,
3110        notifications,
3111    } = db
3112        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3113        .await?;
3114
3115    let mut connection_pool = session.connection_pool().await;
3116    if let Some(membership_update) = membership_update {
3117        notify_membership_updated(
3118            &mut connection_pool,
3119            membership_update,
3120            session.user_id(),
3121            &session.peer,
3122        );
3123    } else {
3124        let update = proto::UpdateChannels {
3125            remove_channel_invitations: vec![channel_id.to_proto()],
3126            ..Default::default()
3127        };
3128
3129        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3130            session.peer.send(connection_id, update.clone())?;
3131        }
3132    };
3133
3134    send_notifications(&connection_pool, &session.peer, notifications);
3135
3136    response.send(proto::Ack {})?;
3137
3138    Ok(())
3139}
3140
3141/// Join the channels' room
3142async fn join_channel(
3143    request: proto::JoinChannel,
3144    response: Response<proto::JoinChannel>,
3145    session: Session,
3146) -> Result<()> {
3147    let channel_id = ChannelId::from_proto(request.channel_id);
3148    join_channel_internal(channel_id, Box::new(response), session).await
3149}
3150
/// Reply handle accepted by `join_channel_internal`.
///
/// It lets the shared join logic answer different request types (here
/// `JoinChannel` and `JoinRoom`), all of which reply with a
/// `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` for the `JoinChannel` request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` for the `JoinRoom` request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3164
/// Joins the call room belonging to `channel_id` on behalf of the session's user.
///
/// Shared by handlers whose reply implements `JoinChannelInternalResponse`.
/// In order, it:
/// 1. evicts the user from a stale room connection, if the database reports one;
/// 2. joins the channel's room in the database;
/// 3. replies with the room, its channel id, and LiveKit connection info (when
///    a LiveKit client is configured and a token could be created);
/// 4. passes any membership change to `notify_membership_updated` and
///    broadcasts the updated room;
/// 5. once the database guard has been released, broadcasts the channel update
///    and refreshes the user's contacts via `update_user_contacts`.
///
/// # Errors
/// Fails if any database step fails, if leaving the stale room fails, if
/// sending the reply fails, if the joined room carries no channel ("channel
/// not returned" — note the reply has already been sent at that point), or
/// if updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the stale room and re-acquire it
            // afterwards; `leave_room_for_session` only receives the session, so it
            // presumably takes the database lock itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and may not publish; every other role receives a
        // room token and may publish. Without a configured LiveKit client, or if token
        // creation fails (the error is traced), no connection info is included.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so the database and connection-pool guards taken
    // inside it are no longer held here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3260
3261/// Start editing the channel notes
3262async fn join_channel_buffer(
3263    request: proto::JoinChannelBuffer,
3264    response: Response<proto::JoinChannelBuffer>,
3265    session: Session,
3266) -> Result<()> {
3267    let db = session.db().await;
3268    let channel_id = ChannelId::from_proto(request.channel_id);
3269
3270    let open_response = db
3271        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3272        .await?;
3273
3274    let collaborators = open_response.collaborators.clone();
3275    response.send(open_response)?;
3276
3277    let update = UpdateChannelBufferCollaborators {
3278        channel_id: channel_id.to_proto(),
3279        collaborators: collaborators.clone(),
3280    };
3281    channel_buffer_updated(
3282        session.connection_id,
3283        collaborators
3284            .iter()
3285            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3286        &update,
3287        &session.peer,
3288    );
3289
3290    Ok(())
3291}
3292
3293/// Edit the channel notes
3294async fn update_channel_buffer(
3295    request: proto::UpdateChannelBuffer,
3296    session: Session,
3297) -> Result<()> {
3298    let db = session.db().await;
3299    let channel_id = ChannelId::from_proto(request.channel_id);
3300
3301    let (collaborators, epoch, version) = db
3302        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3303        .await?;
3304
3305    channel_buffer_updated(
3306        session.connection_id,
3307        collaborators.clone(),
3308        &proto::UpdateChannelBuffer {
3309            channel_id: channel_id.to_proto(),
3310            operations: request.operations,
3311        },
3312        &session.peer,
3313    );
3314
3315    let pool = &*session.connection_pool().await;
3316
3317    let non_collaborators =
3318        pool.channel_connection_ids(channel_id)
3319            .filter_map(|(connection_id, _)| {
3320                if collaborators.contains(&connection_id) {
3321                    None
3322                } else {
3323                    Some(connection_id)
3324                }
3325            });
3326
3327    broadcast(None, non_collaborators, |peer_id| {
3328        session.peer.send(
3329            peer_id,
3330            proto::UpdateChannels {
3331                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3332                    channel_id: channel_id.to_proto(),
3333                    epoch: epoch as u64,
3334                    version: version.clone(),
3335                }],
3336                ..Default::default()
3337            },
3338        )
3339    });
3340
3341    Ok(())
3342}
3343
3344/// Rejoin the channel notes after a connection blip
3345async fn rejoin_channel_buffers(
3346    request: proto::RejoinChannelBuffers,
3347    response: Response<proto::RejoinChannelBuffers>,
3348    session: Session,
3349) -> Result<()> {
3350    let db = session.db().await;
3351    let buffers = db
3352        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3353        .await?;
3354
3355    for rejoined_buffer in &buffers {
3356        let collaborators_to_notify = rejoined_buffer
3357            .buffer
3358            .collaborators
3359            .iter()
3360            .filter_map(|c| Some(c.peer_id?.into()));
3361        channel_buffer_updated(
3362            session.connection_id,
3363            collaborators_to_notify,
3364            &proto::UpdateChannelBufferCollaborators {
3365                channel_id: rejoined_buffer.buffer.channel_id,
3366                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3367            },
3368            &session.peer,
3369        );
3370    }
3371
3372    response.send(proto::RejoinChannelBuffersResponse {
3373        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3374    })?;
3375
3376    Ok(())
3377}
3378
3379/// Stop editing the channel notes
3380async fn leave_channel_buffer(
3381    request: proto::LeaveChannelBuffer,
3382    response: Response<proto::LeaveChannelBuffer>,
3383    session: Session,
3384) -> Result<()> {
3385    let db = session.db().await;
3386    let channel_id = ChannelId::from_proto(request.channel_id);
3387
3388    let left_buffer = db
3389        .leave_channel_buffer(channel_id, session.connection_id)
3390        .await?;
3391
3392    response.send(Ack {})?;
3393
3394    channel_buffer_updated(
3395        session.connection_id,
3396        left_buffer.connections,
3397        &proto::UpdateChannelBufferCollaborators {
3398            channel_id: channel_id.to_proto(),
3399            collaborators: left_buffer.collaborators,
3400        },
3401        &session.peer,
3402    );
3403
3404    Ok(())
3405}
3406
3407fn channel_buffer_updated<T: EnvelopedMessage>(
3408    sender_id: ConnectionId,
3409    collaborators: impl IntoIterator<Item = ConnectionId>,
3410    message: &T,
3411    peer: &Peer,
3412) {
3413    broadcast(Some(sender_id), collaborators, |peer_id| {
3414        peer.send(peer_id, message.clone())
3415    });
3416}
3417
3418fn send_notifications(
3419    connection_pool: &ConnectionPool,
3420    peer: &Peer,
3421    notifications: db::NotificationBatch,
3422) {
3423    for (user_id, notification) in notifications {
3424        for connection_id in connection_pool.user_connection_ids(user_id) {
3425            if let Err(error) = peer.send(
3426                connection_id,
3427                proto::AddNotification {
3428                    notification: Some(notification.clone()),
3429                },
3430            ) {
3431                tracing::error!(
3432                    "failed to send notification to {:?} {}",
3433                    connection_id,
3434                    error
3435                );
3436            }
3437        }
3438    }
3439}
3440
3441/// Send a message to the channel
3442async fn send_channel_message(
3443    request: proto::SendChannelMessage,
3444    response: Response<proto::SendChannelMessage>,
3445    session: Session,
3446) -> Result<()> {
3447    // Validate the message body.
3448    let body = request.body.trim().to_string();
3449    if body.len() > MAX_MESSAGE_LEN {
3450        return Err(anyhow!("message is too long"))?;
3451    }
3452    if body.is_empty() {
3453        return Err(anyhow!("message can't be blank"))?;
3454    }
3455
3456    // TODO: adjust mentions if body is trimmed
3457
3458    let timestamp = OffsetDateTime::now_utc();
3459    let nonce = request
3460        .nonce
3461        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3462
3463    let channel_id = ChannelId::from_proto(request.channel_id);
3464    let CreatedChannelMessage {
3465        message_id,
3466        participant_connection_ids,
3467        notifications,
3468    } = session
3469        .db()
3470        .await
3471        .create_channel_message(
3472            channel_id,
3473            session.user_id(),
3474            &body,
3475            &request.mentions,
3476            timestamp,
3477            nonce.clone().into(),
3478            request.reply_to_message_id.map(MessageId::from_proto),
3479        )
3480        .await?;
3481
3482    let message = proto::ChannelMessage {
3483        sender_id: session.user_id().to_proto(),
3484        id: message_id.to_proto(),
3485        body,
3486        mentions: request.mentions,
3487        timestamp: timestamp.unix_timestamp() as u64,
3488        nonce: Some(nonce),
3489        reply_to_message_id: request.reply_to_message_id,
3490        edited_at: None,
3491    };
3492    broadcast(
3493        Some(session.connection_id),
3494        participant_connection_ids.clone(),
3495        |connection| {
3496            session.peer.send(
3497                connection,
3498                proto::ChannelMessageSent {
3499                    channel_id: channel_id.to_proto(),
3500                    message: Some(message.clone()),
3501                },
3502            )
3503        },
3504    );
3505    response.send(proto::SendChannelMessageResponse {
3506        message: Some(message),
3507    })?;
3508
3509    let pool = &*session.connection_pool().await;
3510    let non_participants =
3511        pool.channel_connection_ids(channel_id)
3512            .filter_map(|(connection_id, _)| {
3513                if participant_connection_ids.contains(&connection_id) {
3514                    None
3515                } else {
3516                    Some(connection_id)
3517                }
3518            });
3519    broadcast(None, non_participants, |peer_id| {
3520        session.peer.send(
3521            peer_id,
3522            proto::UpdateChannels {
3523                latest_channel_message_ids: vec![proto::ChannelMessageId {
3524                    channel_id: channel_id.to_proto(),
3525                    message_id: message_id.to_proto(),
3526                }],
3527                ..Default::default()
3528            },
3529        )
3530    });
3531    send_notifications(pool, &session.peer, notifications);
3532
3533    Ok(())
3534}
3535
3536/// Delete a channel message
3537async fn remove_channel_message(
3538    request: proto::RemoveChannelMessage,
3539    response: Response<proto::RemoveChannelMessage>,
3540    session: Session,
3541) -> Result<()> {
3542    let channel_id = ChannelId::from_proto(request.channel_id);
3543    let message_id = MessageId::from_proto(request.message_id);
3544    let (connection_ids, existing_notification_ids) = session
3545        .db()
3546        .await
3547        .remove_channel_message(channel_id, message_id, session.user_id())
3548        .await?;
3549
3550    broadcast(
3551        Some(session.connection_id),
3552        connection_ids,
3553        move |connection| {
3554            session.peer.send(connection, request.clone())?;
3555
3556            for notification_id in &existing_notification_ids {
3557                session.peer.send(
3558                    connection,
3559                    proto::DeleteNotification {
3560                        notification_id: (*notification_id).to_proto(),
3561                    },
3562                )?;
3563            }
3564
3565            Ok(())
3566        },
3567    );
3568    response.send(proto::Ack {})?;
3569    Ok(())
3570}
3571
3572async fn update_channel_message(
3573    request: proto::UpdateChannelMessage,
3574    response: Response<proto::UpdateChannelMessage>,
3575    session: Session,
3576) -> Result<()> {
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578    let message_id = MessageId::from_proto(request.message_id);
3579    let updated_at = OffsetDateTime::now_utc();
3580    let UpdatedChannelMessage {
3581        message_id,
3582        participant_connection_ids,
3583        notifications,
3584        reply_to_message_id,
3585        timestamp,
3586        deleted_mention_notification_ids,
3587        updated_mention_notifications,
3588    } = session
3589        .db()
3590        .await
3591        .update_channel_message(
3592            channel_id,
3593            message_id,
3594            session.user_id(),
3595            request.body.as_str(),
3596            &request.mentions,
3597            updated_at,
3598        )
3599        .await?;
3600
3601    let nonce = request
3602        .nonce
3603        .clone()
3604        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3605
3606    let message = proto::ChannelMessage {
3607        sender_id: session.user_id().to_proto(),
3608        id: message_id.to_proto(),
3609        body: request.body.clone(),
3610        mentions: request.mentions.clone(),
3611        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3612        nonce: Some(nonce),
3613        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3614        edited_at: Some(updated_at.unix_timestamp() as u64),
3615    };
3616
3617    response.send(proto::Ack {})?;
3618
3619    let pool = &*session.connection_pool().await;
3620    broadcast(
3621        Some(session.connection_id),
3622        participant_connection_ids,
3623        |connection| {
3624            session.peer.send(
3625                connection,
3626                proto::ChannelMessageUpdate {
3627                    channel_id: channel_id.to_proto(),
3628                    message: Some(message.clone()),
3629                },
3630            )?;
3631
3632            for notification_id in &deleted_mention_notification_ids {
3633                session.peer.send(
3634                    connection,
3635                    proto::DeleteNotification {
3636                        notification_id: (*notification_id).to_proto(),
3637                    },
3638                )?;
3639            }
3640
3641            for notification in &updated_mention_notifications {
3642                session.peer.send(
3643                    connection,
3644                    proto::UpdateNotification {
3645                        notification: Some(notification.clone()),
3646                    },
3647                )?;
3648            }
3649
3650            Ok(())
3651        },
3652    );
3653
3654    send_notifications(pool, &session.peer, notifications);
3655
3656    Ok(())
3657}
3658
3659/// Mark a channel message as read
3660async fn acknowledge_channel_message(
3661    request: proto::AckChannelMessage,
3662    session: Session,
3663) -> Result<()> {
3664    let channel_id = ChannelId::from_proto(request.channel_id);
3665    let message_id = MessageId::from_proto(request.message_id);
3666    let notifications = session
3667        .db()
3668        .await
3669        .observe_channel_message(channel_id, session.user_id(), message_id)
3670        .await?;
3671    send_notifications(
3672        &*session.connection_pool().await,
3673        &session.peer,
3674        notifications,
3675    );
3676    Ok(())
3677}
3678
3679/// Mark a buffer version as synced
3680async fn acknowledge_buffer_version(
3681    request: proto::AckBufferOperation,
3682    session: Session,
3683) -> Result<()> {
3684    let buffer_id = BufferId::from_proto(request.buffer_id);
3685    session
3686        .db()
3687        .await
3688        .observe_buffer_version(
3689            buffer_id,
3690            session.user_id(),
3691            request.epoch as i32,
3692            &request.version,
3693        )
3694        .await?;
3695    Ok(())
3696}
3697
3698async fn count_language_model_tokens(
3699    request: proto::CountLanguageModelTokens,
3700    response: Response<proto::CountLanguageModelTokens>,
3701    session: Session,
3702    config: &Config,
3703) -> Result<()> {
3704    authorize_access_to_legacy_llm_endpoints(&session).await?;
3705
3706    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3707        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3708        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3709    };
3710
3711    session
3712        .app_state
3713        .rate_limiter
3714        .check(&*rate_limit, session.user_id())
3715        .await?;
3716
3717    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3718        Some(proto::LanguageModelProvider::Google) => {
3719            let api_key = config
3720                .google_ai_api_key
3721                .as_ref()
3722                .context("no Google AI API key configured on the server")?;
3723            google_ai::count_tokens(
3724                session.http_client.as_ref(),
3725                google_ai::API_URL,
3726                api_key,
3727                serde_json::from_str(&request.request)?,
3728            )
3729            .await?
3730        }
3731        _ => return Err(anyhow!("unsupported provider"))?,
3732    };
3733
3734    response.send(proto::CountLanguageModelTokensResponse {
3735        token_count: result.total_tokens as u32,
3736    })?;
3737
3738    Ok(())
3739}
3740
3741struct ZedProCountLanguageModelTokensRateLimit;
3742
3743impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3744    fn capacity(&self) -> usize {
3745        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3746            .ok()
3747            .and_then(|v| v.parse().ok())
3748            .unwrap_or(600) // Picked arbitrarily
3749    }
3750
3751    fn refill_duration(&self) -> chrono::Duration {
3752        chrono::Duration::hours(1)
3753    }
3754
3755    fn db_name(&self) -> &'static str {
3756        "zed-pro:count-language-model-tokens"
3757    }
3758}
3759
3760struct FreeCountLanguageModelTokensRateLimit;
3761
3762impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3763    fn capacity(&self) -> usize {
3764        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3765            .ok()
3766            .and_then(|v| v.parse().ok())
3767            .unwrap_or(600 / 10) // Picked arbitrarily
3768    }
3769
3770    fn refill_duration(&self) -> chrono::Duration {
3771        chrono::Duration::hours(1)
3772    }
3773
3774    fn db_name(&self) -> &'static str {
3775        "free:count-language-model-tokens"
3776    }
3777}
3778
3779struct ZedProComputeEmbeddingsRateLimit;
3780
3781impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3782    fn capacity(&self) -> usize {
3783        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3784            .ok()
3785            .and_then(|v| v.parse().ok())
3786            .unwrap_or(5000) // Picked arbitrarily
3787    }
3788
3789    fn refill_duration(&self) -> chrono::Duration {
3790        chrono::Duration::hours(1)
3791    }
3792
3793    fn db_name(&self) -> &'static str {
3794        "zed-pro:compute-embeddings"
3795    }
3796}
3797
3798struct FreeComputeEmbeddingsRateLimit;
3799
3800impl RateLimit for FreeComputeEmbeddingsRateLimit {
3801    fn capacity(&self) -> usize {
3802        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3803            .ok()
3804            .and_then(|v| v.parse().ok())
3805            .unwrap_or(5000 / 10) // Picked arbitrarily
3806    }
3807
3808    fn refill_duration(&self) -> chrono::Duration {
3809        chrono::Duration::hours(1)
3810    }
3811
3812    fn db_name(&self) -> &'static str {
3813        "free:compute-embeddings"
3814    }
3815}
3816
3817async fn compute_embeddings(
3818    request: proto::ComputeEmbeddings,
3819    response: Response<proto::ComputeEmbeddings>,
3820    session: Session,
3821    api_key: Option<Arc<str>>,
3822) -> Result<()> {
3823    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3824    authorize_access_to_legacy_llm_endpoints(&session).await?;
3825
3826    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3827        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3828        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3829    };
3830
3831    session
3832        .app_state
3833        .rate_limiter
3834        .check(&*rate_limit, session.user_id())
3835        .await?;
3836
3837    let embeddings = match request.model.as_str() {
3838        "openai/text-embedding-3-small" => {
3839            open_ai::embed(
3840                session.http_client.as_ref(),
3841                OPEN_AI_API_URL,
3842                &api_key,
3843                OpenAiEmbeddingModel::TextEmbedding3Small,
3844                request.texts.iter().map(|text| text.as_str()),
3845            )
3846            .await?
3847        }
3848        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3849    };
3850
3851    let embeddings = request
3852        .texts
3853        .iter()
3854        .map(|text| {
3855            let mut hasher = sha2::Sha256::new();
3856            hasher.update(text.as_bytes());
3857            let result = hasher.finalize();
3858            result.to_vec()
3859        })
3860        .zip(
3861            embeddings
3862                .data
3863                .into_iter()
3864                .map(|embedding| embedding.embedding),
3865        )
3866        .collect::<HashMap<_, _>>();
3867
3868    let db = session.db().await;
3869    db.save_embeddings(&request.model, &embeddings)
3870        .await
3871        .context("failed to save embeddings")
3872        .trace_err();
3873
3874    response.send(proto::ComputeEmbeddingsResponse {
3875        embeddings: embeddings
3876            .into_iter()
3877            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3878            .collect(),
3879    })?;
3880    Ok(())
3881}
3882
3883async fn get_cached_embeddings(
3884    request: proto::GetCachedEmbeddings,
3885    response: Response<proto::GetCachedEmbeddings>,
3886    session: Session,
3887) -> Result<()> {
3888    authorize_access_to_legacy_llm_endpoints(&session).await?;
3889
3890    let db = session.db().await;
3891    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3892
3893    response.send(proto::GetCachedEmbeddingsResponse {
3894        embeddings: embeddings
3895            .into_iter()
3896            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3897            .collect(),
3898    })?;
3899    Ok(())
3900}
3901
3902/// This is leftover from before the LLM service.
3903///
3904/// The endpoints protected by this check will be moved there eventually.
3905async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3906    if session.is_staff() {
3907        Ok(())
3908    } else {
3909        Err(anyhow!("permission denied"))?
3910    }
3911}
3912
3913/// Get a Supermaven API key for the user
3914async fn get_supermaven_api_key(
3915    _request: proto::GetSupermavenApiKey,
3916    response: Response<proto::GetSupermavenApiKey>,
3917    session: Session,
3918) -> Result<()> {
3919    let user_id: String = session.user_id().to_string();
3920    if !session.is_staff() {
3921        return Err(anyhow!("supermaven not enabled for this account"))?;
3922    }
3923
3924    let email = session
3925        .email()
3926        .ok_or_else(|| anyhow!("user must have an email"))?;
3927
3928    let supermaven_admin_api = session
3929        .supermaven_client
3930        .as_ref()
3931        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3932
3933    let result = supermaven_admin_api
3934        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3935        .await?;
3936
3937    response.send(proto::GetSupermavenApiKeyResponse {
3938        api_key: result.api_key,
3939    })?;
3940
3941    Ok(())
3942}
3943
3944/// Start receiving chat updates for a channel
3945async fn join_channel_chat(
3946    request: proto::JoinChannelChat,
3947    response: Response<proto::JoinChannelChat>,
3948    session: Session,
3949) -> Result<()> {
3950    let channel_id = ChannelId::from_proto(request.channel_id);
3951
3952    let db = session.db().await;
3953    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3954        .await?;
3955    let messages = db
3956        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3957        .await?;
3958    response.send(proto::JoinChannelChatResponse {
3959        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3960        messages,
3961    })?;
3962    Ok(())
3963}
3964
3965/// Stop receiving chat updates for a channel
3966async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3967    let channel_id = ChannelId::from_proto(request.channel_id);
3968    session
3969        .db()
3970        .await
3971        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3972        .await?;
3973    Ok(())
3974}
3975
3976/// Retrieve the chat history for a channel
3977async fn get_channel_messages(
3978    request: proto::GetChannelMessages,
3979    response: Response<proto::GetChannelMessages>,
3980    session: Session,
3981) -> Result<()> {
3982    let channel_id = ChannelId::from_proto(request.channel_id);
3983    let messages = session
3984        .db()
3985        .await
3986        .get_channel_messages(
3987            channel_id,
3988            session.user_id(),
3989            MESSAGE_COUNT_PER_PAGE,
3990            Some(MessageId::from_proto(request.before_message_id)),
3991        )
3992        .await?;
3993    response.send(proto::GetChannelMessagesResponse {
3994        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3995        messages,
3996    })?;
3997    Ok(())
3998}
3999
4000/// Retrieve specific chat messages
4001async fn get_channel_messages_by_id(
4002    request: proto::GetChannelMessagesById,
4003    response: Response<proto::GetChannelMessagesById>,
4004    session: Session,
4005) -> Result<()> {
4006    let message_ids = request
4007        .message_ids
4008        .iter()
4009        .map(|id| MessageId::from_proto(*id))
4010        .collect::<Vec<_>>();
4011    let messages = session
4012        .db()
4013        .await
4014        .get_channel_messages_by_id(session.user_id(), &message_ids)
4015        .await?;
4016    response.send(proto::GetChannelMessagesResponse {
4017        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4018        messages,
4019    })?;
4020    Ok(())
4021}
4022
4023/// Retrieve the current users notifications
4024async fn get_notifications(
4025    request: proto::GetNotifications,
4026    response: Response<proto::GetNotifications>,
4027    session: Session,
4028) -> Result<()> {
4029    let notifications = session
4030        .db()
4031        .await
4032        .get_notifications(
4033            session.user_id(),
4034            NOTIFICATION_COUNT_PER_PAGE,
4035            request.before_id.map(db::NotificationId::from_proto),
4036        )
4037        .await?;
4038    response.send(proto::GetNotificationsResponse {
4039        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4040        notifications,
4041    })?;
4042    Ok(())
4043}
4044
4045/// Mark notifications as read
4046async fn mark_notification_as_read(
4047    request: proto::MarkNotificationRead,
4048    response: Response<proto::MarkNotificationRead>,
4049    session: Session,
4050) -> Result<()> {
4051    let database = &session.db().await;
4052    let notifications = database
4053        .mark_notification_as_read_by_id(
4054            session.user_id(),
4055            NotificationId::from_proto(request.notification_id),
4056        )
4057        .await?;
4058    send_notifications(
4059        &*session.connection_pool().await,
4060        &session.peer,
4061        notifications,
4062    );
4063    response.send(proto::Ack {})?;
4064    Ok(())
4065}
4066
4067/// Get the current users information
4068async fn get_private_user_info(
4069    _request: proto::GetPrivateUserInfo,
4070    response: Response<proto::GetPrivateUserInfo>,
4071    session: Session,
4072) -> Result<()> {
4073    let db = session.db().await;
4074
4075    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4076    let user = db
4077        .get_user_by_id(session.user_id())
4078        .await?
4079        .ok_or_else(|| anyhow!("user not found"))?;
4080    let flags = db.get_user_flags(session.user_id()).await?;
4081
4082    response.send(proto::GetPrivateUserInfoResponse {
4083        metrics_id,
4084        staff: user.admin,
4085        flags,
4086        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4087    })?;
4088    Ok(())
4089}
4090
4091/// Accept the terms of service (tos) on behalf of the current user
4092async fn accept_terms_of_service(
4093    _request: proto::AcceptTermsOfService,
4094    response: Response<proto::AcceptTermsOfService>,
4095    session: Session,
4096) -> Result<()> {
4097    let db = session.db().await;
4098
4099    let accepted_tos_at = Utc::now();
4100    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4101        .await?;
4102
4103    response.send(proto::AcceptTermsOfServiceResponse {
4104        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4105    })?;
4106    Ok(())
4107}
4108
/// The minimum account age an account must have in order to use the LLM service.
///
/// Expressed as a `chrono::Duration` of 30 days.
///
/// NOTE(review): `get_llm_api_token` in this file does not reference this constant;
/// confirm where the age requirement is actually enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4111
4112async fn get_llm_api_token(
4113    _request: proto::GetLlmToken,
4114    response: Response<proto::GetLlmToken>,
4115    session: Session,
4116) -> Result<()> {
4117    let db = session.db().await;
4118
4119    let flags = db.get_user_flags(session.user_id()).await?;
4120    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4121
4122    if !session.is_staff() && !has_language_models_feature_flag {
4123        Err(anyhow!("permission denied"))?
4124    }
4125
4126    let user_id = session.user_id();
4127    let user = db
4128        .get_user_by_id(user_id)
4129        .await?
4130        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4131
4132    if user.accepted_tos_at.is_none() {
4133        Err(anyhow!("terms of service not accepted"))?
4134    }
4135
4136    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4137    let billing_preferences = db.get_billing_preferences(user.id).await?;
4138
4139    let token = LlmTokenClaims::create(
4140        &user,
4141        session.is_staff(),
4142        billing_preferences,
4143        &flags,
4144        has_llm_subscription,
4145        session.current_plan(&db).await?,
4146        session.system_id.clone(),
4147        &session.app_state.config,
4148    )?;
4149    response.send(proto::GetLlmTokenResponse { token })?;
4150    Ok(())
4151}
4152
4153fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4154    let message = match message {
4155        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4156        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4157        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4158        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4159        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4160            code: frame.code.into(),
4161            reason: frame.reason,
4162        })),
4163        // We should never receive a frame while reading the message, according
4164        // to the `tungstenite` maintainers:
4165        //
4166        // > It cannot occur when you read messages from the WebSocket, but it
4167        // > can be used when you want to send the raw frames (e.g. you want to
4168        // > send the frames to the WebSocket without composing the full message first).
4169        // >
4170        // > — https://github.com/snapview/tungstenite-rs/issues/268
4171        TungsteniteMessage::Frame(_) => {
4172            bail!("received an unexpected frame while reading the message")
4173        }
4174    };
4175
4176    Ok(message)
4177}
4178
4179fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4180    match message {
4181        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4182        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4183        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4184        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4185        AxumMessage::Close(frame) => {
4186            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4187                code: frame.code.into(),
4188                reason: frame.reason,
4189            }))
4190        }
4191    }
4192}
4193
4194fn notify_membership_updated(
4195    connection_pool: &mut ConnectionPool,
4196    result: MembershipUpdated,
4197    user_id: UserId,
4198    peer: &Peer,
4199) {
4200    for membership in &result.new_channels.channel_memberships {
4201        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4202    }
4203    for channel_id in &result.removed_channels {
4204        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4205    }
4206
4207    let user_channels_update = proto::UpdateUserChannels {
4208        channel_memberships: result
4209            .new_channels
4210            .channel_memberships
4211            .iter()
4212            .map(|cm| proto::ChannelMembership {
4213                channel_id: cm.channel_id.to_proto(),
4214                role: cm.role.into(),
4215            })
4216            .collect(),
4217        ..Default::default()
4218    };
4219
4220    let mut update = build_channels_update(result.new_channels);
4221    update.delete_channels = result
4222        .removed_channels
4223        .into_iter()
4224        .map(|id| id.to_proto())
4225        .collect();
4226    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4227
4228    for connection_id in connection_pool.user_connection_ids(user_id) {
4229        peer.send(connection_id, user_channels_update.clone())
4230            .trace_err();
4231        peer.send(connection_id, update.clone()).trace_err();
4232    }
4233}
4234
4235fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4236    proto::UpdateUserChannels {
4237        channel_memberships: channels
4238            .channel_memberships
4239            .iter()
4240            .map(|m| proto::ChannelMembership {
4241                channel_id: m.channel_id.to_proto(),
4242                role: m.role.into(),
4243            })
4244            .collect(),
4245        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4246        observed_channel_message_id: channels.observed_channel_messages.clone(),
4247    }
4248}
4249
4250fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4251    let mut update = proto::UpdateChannels::default();
4252
4253    for channel in channels.channels {
4254        update.channels.push(channel.to_proto());
4255    }
4256
4257    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4258    update.latest_channel_message_ids = channels.latest_channel_messages;
4259
4260    for (channel_id, participants) in channels.channel_participants {
4261        update
4262            .channel_participants
4263            .push(proto::ChannelParticipants {
4264                channel_id: channel_id.to_proto(),
4265                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4266            });
4267    }
4268
4269    for channel in channels.invited_channels {
4270        update.channel_invitations.push(channel.to_proto());
4271    }
4272
4273    update
4274}
4275
4276fn build_initial_contacts_update(
4277    contacts: Vec<db::Contact>,
4278    pool: &ConnectionPool,
4279) -> proto::UpdateContacts {
4280    let mut update = proto::UpdateContacts::default();
4281
4282    for contact in contacts {
4283        match contact {
4284            db::Contact::Accepted { user_id, busy } => {
4285                update.contacts.push(contact_for_user(user_id, busy, pool));
4286            }
4287            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4288            db::Contact::Incoming { user_id } => {
4289                update
4290                    .incoming_requests
4291                    .push(proto::IncomingContactRequest {
4292                        requester_id: user_id.to_proto(),
4293                    })
4294            }
4295        }
4296    }
4297
4298    update
4299}
4300
4301fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4302    proto::Contact {
4303        user_id: user_id.to_proto(),
4304        online: pool.is_user_online(user_id),
4305        busy,
4306    }
4307}
4308
4309fn room_updated(room: &proto::Room, peer: &Peer) {
4310    broadcast(
4311        None,
4312        room.participants
4313            .iter()
4314            .filter_map(|participant| Some(participant.peer_id?.into())),
4315        |peer_id| {
4316            peer.send(
4317                peer_id,
4318                proto::RoomUpdated {
4319                    room: Some(room.clone()),
4320                },
4321            )
4322        },
4323    );
4324}
4325
4326fn channel_updated(
4327    channel: &db::channel::Model,
4328    room: &proto::Room,
4329    peer: &Peer,
4330    pool: &ConnectionPool,
4331) {
4332    let participants = room
4333        .participants
4334        .iter()
4335        .map(|p| p.user_id)
4336        .collect::<Vec<_>>();
4337
4338    broadcast(
4339        None,
4340        pool.channel_connection_ids(channel.root_id())
4341            .filter_map(|(channel_id, role)| {
4342                role.can_see_channel(channel.visibility)
4343                    .then_some(channel_id)
4344            }),
4345        |peer_id| {
4346            peer.send(
4347                peer_id,
4348                proto::UpdateChannels {
4349                    channel_participants: vec![proto::ChannelParticipants {
4350                        channel_id: channel.id.to_proto(),
4351                        participant_user_ids: participants.clone(),
4352                    }],
4353                    ..Default::default()
4354                },
4355            )
4356        },
4357    );
4358}
4359
4360async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4361    let db = session.db().await;
4362
4363    let contacts = db.get_contacts(user_id).await?;
4364    let busy = db.is_user_busy(user_id).await?;
4365
4366    let pool = session.connection_pool().await;
4367    let updated_contact = contact_for_user(user_id, busy, &pool);
4368    for contact in contacts {
4369        if let db::Contact::Accepted {
4370            user_id: contact_user_id,
4371            ..
4372        } = contact
4373        {
4374            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4375                session
4376                    .peer
4377                    .send(
4378                        contact_conn_id,
4379                        proto::UpdateContacts {
4380                            contacts: vec![updated_contact.clone()],
4381                            remove_contacts: Default::default(),
4382                            incoming_requests: Default::default(),
4383                            remove_incoming_requests: Default::default(),
4384                            outgoing_requests: Default::default(),
4385                            remove_outgoing_requests: Default::default(),
4386                        },
4387                    )
4388                    .trace_err();
4389            }
4390        }
4391    }
4392    Ok(())
4393}
4394
4395async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4396    let mut contacts_to_update = HashSet::default();
4397
4398    let room_id;
4399    let canceled_calls_to_user_ids;
4400    let livekit_room;
4401    let delete_livekit_room;
4402    let room;
4403    let channel;
4404
4405    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4406        contacts_to_update.insert(session.user_id());
4407
4408        for project in left_room.left_projects.values() {
4409            project_left(project, session);
4410        }
4411
4412        room_id = RoomId::from_proto(left_room.room.id);
4413        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4414        livekit_room = mem::take(&mut left_room.room.livekit_room);
4415        delete_livekit_room = left_room.deleted;
4416        room = mem::take(&mut left_room.room);
4417        channel = mem::take(&mut left_room.channel);
4418
4419        room_updated(&room, &session.peer);
4420    } else {
4421        return Ok(());
4422    }
4423
4424    if let Some(channel) = channel {
4425        channel_updated(
4426            &channel,
4427            &room,
4428            &session.peer,
4429            &*session.connection_pool().await,
4430        );
4431    }
4432
4433    {
4434        let pool = session.connection_pool().await;
4435        for canceled_user_id in canceled_calls_to_user_ids {
4436            for connection_id in pool.user_connection_ids(canceled_user_id) {
4437                session
4438                    .peer
4439                    .send(
4440                        connection_id,
4441                        proto::CallCanceled {
4442                            room_id: room_id.to_proto(),
4443                        },
4444                    )
4445                    .trace_err();
4446            }
4447            contacts_to_update.insert(canceled_user_id);
4448        }
4449    }
4450
4451    for contact_user_id in contacts_to_update {
4452        update_user_contacts(contact_user_id, session).await?;
4453    }
4454
4455    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4456        live_kit
4457            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4458            .await
4459            .trace_err();
4460
4461        if delete_livekit_room {
4462            live_kit.delete_room(livekit_room).await.trace_err();
4463        }
4464    }
4465
4466    Ok(())
4467}
4468
4469async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4470    let left_channel_buffers = session
4471        .db()
4472        .await
4473        .leave_channel_buffers(session.connection_id)
4474        .await?;
4475
4476    for left_buffer in left_channel_buffers {
4477        channel_buffer_updated(
4478            session.connection_id,
4479            left_buffer.connections,
4480            &proto::UpdateChannelBufferCollaborators {
4481                channel_id: left_buffer.channel_id.to_proto(),
4482                collaborators: left_buffer.collaborators,
4483            },
4484            &session.peer,
4485        );
4486    }
4487
4488    Ok(())
4489}
4490
4491fn project_left(project: &db::LeftProject, session: &Session) {
4492    for connection_id in &project.connection_ids {
4493        if project.should_unshare {
4494            session
4495                .peer
4496                .send(
4497                    *connection_id,
4498                    proto::UnshareProject {
4499                        project_id: project.id.to_proto(),
4500                    },
4501                )
4502                .trace_err();
4503        } else {
4504            session
4505                .peer
4506                .send(
4507                    *connection_id,
4508                    proto::RemoveProjectCollaborator {
4509                        project_id: project.id.to_proto(),
4510                        peer_id: Some(session.connection_id.into()),
4511                    },
4512                )
4513                .trace_err();
4514        }
4515    }
4516}
4517
4518pub trait ResultExt {
4519    type Ok;
4520
4521    fn trace_err(self) -> Option<Self::Ok>;
4522}
4523
4524impl<T, E> ResultExt for Result<T, E>
4525where
4526    E: std::fmt::Debug,
4527{
4528    type Ok = T;
4529
4530    #[track_caller]
4531    fn trace_err(self) -> Option<T> {
4532        match self {
4533            Ok(value) => Some(value),
4534            Err(error) => {
4535                tracing::error!("{:?}", error);
4536                None
4537            }
4538        }
4539    }
4540}