rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::stripe_client::StripeCustomerId;
   9use crate::{
  10    AppState, Error, Result, auth,
  11    db::{
  12        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  13        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  14        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  15        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  16    },
  17    executor::Executor,
  18};
  19use anyhow::{Context as _, anyhow, bail};
  20use async_tungstenite::tungstenite::{
  21    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  22};
  23use axum::{
  24    Extension, Router, TypedHeader,
  25    body::Body,
  26    extract::{
  27        ConnectInfo, WebSocketUpgrade,
  28        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  29    },
  30    headers::{Header, HeaderName},
  31    http::StatusCode,
  32    middleware,
  33    response::IntoResponse,
  34    routing::get,
  35};
  36use chrono::Utc;
  37use collections::{HashMap, HashSet};
  38pub use connection_pool::{ConnectionPool, ZedVersion};
  39use core::fmt::{self, Debug, Formatter};
  40use reqwest_client::ReqwestClient;
  41use rpc::proto::split_repository_update;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  46    stream::FuturesUnordered,
  47};
  48use prometheus::{IntGauge, register_int_gauge};
  49use rpc::{
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        Arc, OnceLock,
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{Semaphore, watch};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    Instrument,
  77    field::{self},
  78    info_span, instrument,
  79};
  80
/// A 30-second timeout related to reconnection.
///
/// NOTE(review): no use site is visible in this part of the file — confirm the exact
/// semantics against the callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long the background task spawned by `Server::start` waits before it starts cleaning
/// up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): presumably the page size used when listing channel messages — confirm at use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// NOTE(review): presumably the maximum accepted length of a channel message — confirm the unit
// (bytes vs. characters) at use sites.
const MAX_MESSAGE_LEN: usize = 1024;
// NOTE(review): presumably the page size used when listing notifications — confirm at use sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
/// Type-erased message handler as stored in `Server::handlers`.
///
/// Takes a boxed, type-erased envelope together with the `Session` it arrived on and returns a
/// boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
/// Handle passed to a request handler so it can reply to a request of type `R`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Receipt of the request being answered; handed to `Peer::respond` by `send`.
    receipt: Receipt<R>,
    /// Flag set to `true` by `send`; shared with whoever created this handle.
    responded: Arc<AtomicBool>,
}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` is acting as `user`.
    Impersonated { user: User, admin: User },
}
 112
 113impl Principal {
 114    fn user(&self) -> &User {
 115        match self {
 116            Principal::User(user) => user,
 117            Principal::Impersonated { user, .. } => user,
 118        }
 119    }
 120
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132        }
 133    }
 134}
 135
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or an admin impersonating a user).
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer; request handlers clone it to send their responses.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; acquire it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): presumably the client's system id taken from `SystemIdHeader` — confirm
    // where the session is constructed.
    system_id: Option<String>,
    /// Executor kept alongside the session; not read directly (underscore-prefixed).
    _executor: Executor,
}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn is_staff(&self) -> bool {
 173        match &self.principal {
 174            Principal::User(user) => user.admin,
 175            Principal::Impersonated { .. } => true,
 176        }
 177    }
 178
 179    fn user_id(&self) -> UserId {
 180        match &self.principal {
 181            Principal::User(user) => user.id,
 182            Principal::Impersonated { user, .. } => user.id,
 183        }
 184    }
 185
 186    pub fn email(&self) -> Option<String> {
 187        match &self.principal {
 188            Principal::User(user) => user.email_address.clone(),
 189            Principal::Impersonated { user, .. } => user.email_address.clone(),
 190        }
 191    }
 192}
 193
 194impl Debug for Session {
 195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 196        let mut result = f.debug_struct("Session");
 197        match &self.principal {
 198            Principal::User(user) => {
 199                result.field("user", &user.github_login);
 200            }
 201            Principal::Impersonated { user, admin } => {
 202                result.field("user", &user.github_login);
 203                result.field("impersonator", &admin.github_login);
 204            }
 205        }
 206        result.field("connection_id", &self.connection_id).finish()
 207    }
 208}
 209
/// Newtype around a shared `Database`; dereferences to the database itself.
struct DbHandle(Arc<Database>);
 211
 212impl Deref for DbHandle {
 213    type Target = Database;
 214
 215    fn deref(&self) -> &Self::Target {
 216        self.0.as_ref()
 217    }
 218}
 219
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer shared with sessions and background tasks.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with sessions.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel sender: `teardown` sends `true`, the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 228
/// Wrapper around a held lock on the `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the wrapper `!Send` as well.
    // NOTE(review): presumably this is to keep the guard from being held across `.await`
    // points in futures that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
 233
/// Serializable view of the server: its peer plus a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's RPC peer.
    peer: &'a Peer,
    /// Held lock on the connection pool; serialized via `serialize_deref`, i.e. as the value
    /// the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 240
 241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 242where
 243    S: Serializer,
 244    T: Deref<Target = U>,
 245    U: Serialize,
 246{
 247    Serialize::serialize(value.deref(), serializer)
 248}
 249
 250impl Server {
    /// Creates a server with the given id and registers a handler for every supported
    /// message type.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (enforced by
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // The peer is created from the server id narrowed to `u32` with `as`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half of the watch channel is kept; the initial value is `false`.
            teardown: watch::channel(false).0,
        };

        // `add_request_handler` is used for messages that expect a response,
        // `add_message_handler` for those that do not.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` go through the read-only
            // forwarder although their names suggest they modify the working tree — confirm
            // this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
 426
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT` has elapsed, cleans
    /// up stale resources recorded in the database and notifies affected connections.
    ///
    /// This function itself only schedules the task and then returns `Ok(())`. Inside the
    /// task every fallible step goes through `trace_err()`, so a failing step does not abort
    /// the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        // Capture everything the detached task needs up front; it cannot borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators, then send the refreshed
                    // collaborator list to each connection id returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: clear stale participants, broadcast the refreshed room, cancel
                    // pending calls, refresh contacts, and delete the LiveKit room if the
                    // room ended up empty.
                    for room_id in room_ids {
                        // These stay at their empty defaults when clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block, which contains no `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both database queries complete before the pool lock is taken.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made right after the timeout —
                // confirm the second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 597
 598    pub fn teardown(&self) {
 599        self.peer.teardown();
 600        self.connection_pool.lock().reset();
 601        let _ = self.teardown.send(true);
 602    }
 603
 604    #[cfg(test)]
 605    pub fn reset(&self, id: ServerId) {
 606        self.teardown();
 607        *self.id.lock() = id;
 608        self.peer.reset(id.0 as u32);
 609        let _ = self.teardown.send(false);
 610    }
 611
 612    #[cfg(test)]
 613    pub fn id(&self) -> ServerId {
 614        *self.id.lock()
 615    }
 616
 617    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 618    where
 619        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 620        Fut: 'static + Send + Future<Output = Result<()>>,
 621        M: EnvelopedMessage,
 622    {
 623        let prev_handler = self.handlers.insert(
 624            TypeId::of::<M>(),
 625            Box::new(move |envelope, session| {
 626                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 627                let received_at = envelope.received_at;
 628                tracing::info!("message received");
 629                let start_time = Instant::now();
 630                let future = (handler)(*envelope, session);
 631                async move {
 632                    let result = future.await;
 633                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 634                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 635                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 636                    let payload_type = M::NAME;
 637
 638                    match result {
 639                        Err(error) => {
 640                            tracing::error!(
 641                                ?error,
 642                                total_duration_ms,
 643                                processing_duration_ms,
 644                                queue_duration_ms,
 645                                payload_type,
 646                                "error handling message"
 647                            )
 648                        }
 649                        Ok(()) => tracing::info!(
 650                            total_duration_ms,
 651                            processing_duration_ms,
 652                            queue_duration_ms,
 653                            "finished handling message"
 654                        ),
 655                    }
 656                }
 657                .boxed()
 658            }),
 659        );
 660        if prev_handler.is_some() {
 661            panic!("registered a handler for the same message twice");
 662        }
 663        self
 664    }
 665
 666    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 667    where
 668        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 669        Fut: 'static + Send + Future<Output = Result<()>>,
 670        M: EnvelopedMessage,
 671    {
 672        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 673        self
 674    }
 675
 676    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 677    where
 678        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 679        Fut: Send + Future<Output = Result<()>>,
 680        M: RequestMessage,
 681    {
 682        let handler = Arc::new(handler);
 683        self.add_handler(move |envelope, session| {
 684            let receipt = envelope.receipt();
 685            let handler = handler.clone();
 686            async move {
 687                let peer = session.peer.clone();
 688                let responded = Arc::new(AtomicBool::default());
 689                let response = Response {
 690                    peer: peer.clone(),
 691                    responded: responded.clone(),
 692                    receipt,
 693                };
 694                match (handler)(envelope.payload, response, session).await {
 695                    Ok(()) => {
 696                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 697                            Ok(())
 698                        } else {
 699                            Err(anyhow!("handler did not send a response"))?
 700                        }
 701                    }
 702                    Err(error) => {
 703                        let proto_err = match &error {
 704                            Error::Internal(err) => err.to_proto(),
 705                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 706                        };
 707                        peer.respond_with_error(receipt, proto_err)?;
 708                        Err(error)
 709                    }
 710                }
 711            }
 712        })
 713    }
 714
    /// Returns a future that drives a single client connection to completion.
    ///
    /// The future registers `connection` with the peer, builds a `Session`,
    /// sends the initial client update, and then loops dispatching incoming
    /// messages to the registered handlers. The loop ends when the incoming
    /// stream closes or the I/O future finishes, after which `connection_lost`
    /// is awaited. If the server's teardown signal changes while the loop is
    /// running, the future returns immediately without running `connection_lost`.
    ///
    /// The whole future is instrumented with a "handle connection" span
    /// carrying `address`, the connection id, the principal's fields and the
    /// GeoIP country code (when present).
    ///
    /// If `send_connection_id` is provided it is forwarded to
    /// `send_initial_client_update`, which uses it to report the new
    /// connection id.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared as `field::Empty` are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once the teardown flag is already set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this return happens after `add_connection` but skips
                    // `connection_lost` — confirm whether the peer entry needs cleanup here.
                    return;
                }
            };

            // The Supermaven client only exists when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): as above, this return skips `connection_lost` — confirm
                // whether the peer/pool entries need cleanup on this path.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection: a permit is acquired before the
            // next message is read and released only when that message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written, so teardown and
                // I/O completion take priority over handling further messages.
                futures::select_biased! {
                    // Server teardown: stop immediately, skipping the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive pending foreground handlers; their output is `()`.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is instrumented with it below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are detached onto the executor; foreground
                                // ones are polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 862
 863    async fn send_initial_client_update(
 864        &self,
 865        connection_id: ConnectionId,
 866        zed_version: ZedVersion,
 867        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 868        session: &Session,
 869    ) -> Result<()> {
 870        self.peer.send(
 871            connection_id,
 872            proto::Hello {
 873                peer_id: Some(connection_id.into()),
 874            },
 875        )?;
 876        tracing::info!("sent hello message");
 877        if let Some(send_connection_id) = send_connection_id.take() {
 878            let _ = send_connection_id.send(connection_id);
 879        }
 880
 881        match &session.principal {
 882            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 883                if !user.connected_once {
 884                    self.peer.send(connection_id, proto::ShowContacts {})?;
 885                    self.app_state
 886                        .db
 887                        .set_user_connected_once(user.id, true)
 888                        .await?;
 889                }
 890
 891                update_user_plan(session).await?;
 892
 893                let contacts = self.app_state.db.get_contacts(user.id).await?;
 894
 895                {
 896                    let mut pool = self.connection_pool.lock();
 897                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 898                    self.peer.send(
 899                        connection_id,
 900                        build_initial_contacts_update(contacts, &pool),
 901                    )?;
 902                }
 903
 904                if should_auto_subscribe_to_channels(zed_version) {
 905                    subscribe_user_to_channels(user.id, session).await?;
 906                }
 907
 908                if let Some(incoming_call) =
 909                    self.app_state.db.incoming_call_for_user(user.id).await?
 910                {
 911                    self.peer.send(connection_id, incoming_call)?;
 912                }
 913
 914                update_user_contacts(user.id, session).await?;
 915            }
 916        }
 917
 918        Ok(())
 919    }
 920
 921    pub async fn invite_code_redeemed(
 922        self: &Arc<Self>,
 923        inviter_id: UserId,
 924        invitee_id: UserId,
 925    ) -> Result<()> {
 926        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 927            if let Some(code) = &user.invite_code {
 928                let pool = self.connection_pool.lock();
 929                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 930                for connection_id in pool.user_connection_ids(inviter_id) {
 931                    self.peer.send(
 932                        connection_id,
 933                        proto::UpdateContacts {
 934                            contacts: vec![invitee_contact.clone()],
 935                            ..Default::default()
 936                        },
 937                    )?;
 938                    self.peer.send(
 939                        connection_id,
 940                        proto::UpdateInviteInfo {
 941                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 942                            count: user.invite_count as u32,
 943                        },
 944                    )?;
 945                }
 946            }
 947        }
 948        Ok(())
 949    }
 950
 951    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 952        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 953            if let Some(invite_code) = &user.invite_code {
 954                let pool = self.connection_pool.lock();
 955                for connection_id in pool.user_connection_ids(user_id) {
 956                    self.peer.send(
 957                        connection_id,
 958                        proto::UpdateInviteInfo {
 959                            url: format!(
 960                                "{}{}",
 961                                self.app_state.config.invite_link_prefix, invite_code
 962                            ),
 963                            count: user.invite_count as u32,
 964                        },
 965                    )?;
 966                }
 967            }
 968        }
 969        Ok(())
 970    }
 971
 972    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 973        let user = self
 974            .app_state
 975            .db
 976            .get_user_by_id(user_id)
 977            .await?
 978            .context("user not found")?;
 979
 980        let update_user_plan = make_update_user_plan_message(
 981            &user,
 982            user.admin,
 983            &self.app_state.db,
 984            self.app_state.llm_db.clone(),
 985        )
 986        .await?;
 987
 988        let pool = self.connection_pool.lock();
 989        for connection_id in pool.user_connection_ids(user_id) {
 990            self.peer
 991                .send(connection_id, update_user_plan.clone())
 992                .trace_err();
 993        }
 994
 995        Ok(())
 996    }
 997
 998    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 999        let pool = self.connection_pool.lock();
1000        for connection_id in pool.user_connection_ids(user_id) {
1001            self.peer
1002                .send(connection_id, proto::RefreshLlmToken {})
1003                .trace_err();
1004        }
1005    }
1006
1007    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1008        ServerSnapshot {
1009            connection_pool: ConnectionPoolGuard {
1010                guard: self.connection_pool.lock(),
1011                _not_send: PhantomData,
1012            },
1013            peer: &self.peer,
1014        }
1015    }
1016}
1017
1018impl Deref for ConnectionPoolGuard<'_> {
1019    type Target = ConnectionPool;
1020
1021    fn deref(&self) -> &Self::Target {
1022        &self.guard
1023    }
1024}
1025
1026impl DerefMut for ConnectionPoolGuard<'_> {
1027    fn deref_mut(&mut self) -> &mut Self::Target {
1028        &mut self.guard
1029    }
1030}
1031
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, runs `check_invariants` when the guard is dropped.
    /// In non-test builds the body is empty.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1038
1039fn broadcast<F>(
1040    sender_id: Option<ConnectionId>,
1041    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1042    mut f: F,
1043) where
1044    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1045{
1046    for receiver_id in receiver_ids {
1047        if Some(receiver_id) != sender_id {
1048            if let Err(error) = f(receiver_id) {
1049                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1050            }
1051        }
1052    }
1053}
1054
1055pub struct ProtocolVersion(u32);
1056
1057impl Header for ProtocolVersion {
1058    fn name() -> &'static HeaderName {
1059        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1060        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1061    }
1062
1063    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064    where
1065        Self: Sized,
1066        I: Iterator<Item = &'i axum::http::HeaderValue>,
1067    {
1068        let version = values
1069            .next()
1070            .ok_or_else(axum::headers::Error::invalid)?
1071            .to_str()
1072            .map_err(|_| axum::headers::Error::invalid())?
1073            .parse()
1074            .map_err(|_| axum::headers::Error::invalid())?;
1075        Ok(Self(version))
1076    }
1077
1078    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079        values.extend([self.0.to_string().parse().unwrap()]);
1080    }
1081}
1082
1083pub struct AppVersionHeader(SemanticVersion);
1084impl Header for AppVersionHeader {
1085    fn name() -> &'static HeaderName {
1086        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1087        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1088    }
1089
1090    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1091    where
1092        Self: Sized,
1093        I: Iterator<Item = &'i axum::http::HeaderValue>,
1094    {
1095        let version = values
1096            .next()
1097            .ok_or_else(axum::headers::Error::invalid)?
1098            .to_str()
1099            .map_err(|_| axum::headers::Error::invalid())?
1100            .parse()
1101            .map_err(|_| axum::headers::Error::invalid())?;
1102        Ok(Self(version))
1103    }
1104
1105    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1106        values.extend([self.0.to_string().parse().unwrap()]);
1107    }
1108}
1109
1110pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1111    Router::new()
1112        .route("/rpc", get(handle_websocket_request))
1113        .layer(
1114            ServiceBuilder::new()
1115                .layer(Extension(server.app_state.clone()))
1116                .layer(middleware::from_fn(auth::validate_header)),
1117        )
1118        .route("/metrics", get(handle_metrics))
1119        .layer(Extension(server))
1120}
1121
1122pub async fn handle_websocket_request(
1123    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1124    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1125    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1126    Extension(server): Extension<Arc<Server>>,
1127    Extension(principal): Extension<Principal>,
1128    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1129    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1130    ws: WebSocketUpgrade,
1131) -> axum::response::Response {
1132    if protocol_version != rpc::PROTOCOL_VERSION {
1133        return (
1134            StatusCode::UPGRADE_REQUIRED,
1135            "client must be upgraded".to_string(),
1136        )
1137            .into_response();
1138    }
1139
1140    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1141        return (
1142            StatusCode::UPGRADE_REQUIRED,
1143            "no version header found".to_string(),
1144        )
1145            .into_response();
1146    };
1147
1148    if !version.can_collaborate() {
1149        return (
1150            StatusCode::UPGRADE_REQUIRED,
1151            "client must be upgraded".to_string(),
1152        )
1153            .into_response();
1154    }
1155
1156    let socket_address = socket_address.to_string();
1157    ws.on_upgrade(move |socket| {
1158        let socket = socket
1159            .map_ok(to_tungstenite_message)
1160            .err_into()
1161            .with(|message| async move { to_axum_message(message) });
1162        let connection = Connection::new(Box::pin(socket));
1163        async move {
1164            server
1165                .handle_connection(
1166                    connection,
1167                    socket_address,
1168                    principal,
1169                    version,
1170                    country_code_header.map(|header| header.to_string()),
1171                    system_id_header.map(|header| header.to_string()),
1172                    None,
1173                    Executor::Production,
1174                )
1175                .await;
1176        }
1177    })
1178}
1179
1180pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1181    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1182    let connections_metric = CONNECTIONS_METRIC
1183        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1184
1185    let connections = server
1186        .connection_pool
1187        .lock()
1188        .connections()
1189        .filter(|connection| !connection.admin)
1190        .count();
1191    connections_metric.set(connections as _);
1192
1193    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1194    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1195        register_int_gauge!(
1196            "shared_projects",
1197            "number of open projects with one or more guests"
1198        )
1199        .unwrap()
1200    });
1201
1202    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1203    shared_projects_metric.set(shared_projects as _);
1204
1205    let encoder = prometheus::TextEncoder::new();
1206    let metric_families = prometheus::gather();
1207    let encoded_metrics = encoder
1208        .encode_to_string(&metric_families)
1209        .map_err(|err| anyhow!("{err}"))?;
1210    Ok(encoded_metrics)
1211}
1212
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool
/// and records the loss in the database (a database error here is traced,
/// not returned). It then waits for either `RECONNECT_TIMEOUT` to elapse
/// or the teardown signal to change:
///
/// * on timeout, the session leaves its room and channel buffers, declines
///   any call for the user if the pool reports the user is no longer
///   online, and refreshes the user's contacts;
/// * on teardown, no further cleanup is performed.
///
/// Returns an error if removing the connection from the pool or the final
/// contacts update fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a failure is traced and cleanup continues.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout arm first, then the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a call when the pool reports the user as no longer online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Teardown signalled before the timeout: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1259
1260/// Acknowledges a ping from a client, used to keep the connection alive.
1261async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1262    response.send(proto::Ack {})?;
1263    Ok(())
1264}
1265
1266/// Creates a new room for calling (outside of channels)
1267async fn create_room(
1268    _request: proto::CreateRoom,
1269    response: Response<proto::CreateRoom>,
1270    session: Session,
1271) -> Result<()> {
1272    let livekit_room = nanoid::nanoid!(30);
1273
1274    let live_kit_connection_info = util::maybe!(async {
1275        let live_kit = session.app_state.livekit_client.as_ref();
1276        let live_kit = live_kit?;
1277        let user_id = session.user_id().to_string();
1278
1279        let token = live_kit
1280            .room_token(&livekit_room, &user_id.to_string())
1281            .trace_err()?;
1282
1283        Some(proto::LiveKitConnectionInfo {
1284            server_url: live_kit.url().into(),
1285            token,
1286            can_publish: true,
1287        })
1288    })
1289    .await;
1290
1291    let room = session
1292        .db()
1293        .await
1294        .create_room(session.user_id(), session.connection_id, &livekit_room)
1295        .await?;
1296
1297    response.send(proto::CreateRoomResponse {
1298        room: Some(room.clone()),
1299        live_kit_connection_info,
1300    })?;
1301
1302    update_user_contacts(session.user_id(), &session).await?;
1303    Ok(())
1304}
1305
1306/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1307async fn join_room(
1308    request: proto::JoinRoom,
1309    response: Response<proto::JoinRoom>,
1310    session: Session,
1311) -> Result<()> {
1312    let room_id = RoomId::from_proto(request.id);
1313
1314    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1315
1316    if let Some(channel_id) = channel_id {
1317        return join_channel_internal(channel_id, Box::new(response), session).await;
1318    }
1319
1320    let joined_room = {
1321        let room = session
1322            .db()
1323            .await
1324            .join_room(room_id, session.user_id(), session.connection_id)
1325            .await?;
1326        room_updated(&room.room, &session.peer);
1327        room.into_inner()
1328    };
1329
1330    for connection_id in session
1331        .connection_pool()
1332        .await
1333        .user_connection_ids(session.user_id())
1334    {
1335        session
1336            .peer
1337            .send(
1338                connection_id,
1339                proto::CallCanceled {
1340                    room_id: room_id.to_proto(),
1341                },
1342            )
1343            .trace_err();
1344    }
1345
1346    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1347    {
1348        live_kit
1349            .room_token(
1350                &joined_room.room.livekit_room,
1351                &session.user_id().to_string(),
1352            )
1353            .trace_err()
1354            .map(|token| proto::LiveKitConnectionInfo {
1355                server_url: live_kit.url().into(),
1356                token,
1357                can_publish: true,
1358            })
1359    } else {
1360        None
1361    };
1362
1363    response.send(proto::JoinRoomResponse {
1364        room: Some(joined_room.room),
1365        channel_id: None,
1366        live_kit_connection_info,
1367    })?;
1368
1369    update_user_contacts(session.user_id(), &session).await?;
1370    Ok(())
1371}
1372
1373/// Rejoin room is used to reconnect to a room after connection errors.
1374async fn rejoin_room(
1375    request: proto::RejoinRoom,
1376    response: Response<proto::RejoinRoom>,
1377    session: Session,
1378) -> Result<()> {
1379    let room;
1380    let channel;
1381    {
1382        let mut rejoined_room = session
1383            .db()
1384            .await
1385            .rejoin_room(request, session.user_id(), session.connection_id)
1386            .await?;
1387
1388        response.send(proto::RejoinRoomResponse {
1389            room: Some(rejoined_room.room.clone()),
1390            reshared_projects: rejoined_room
1391                .reshared_projects
1392                .iter()
1393                .map(|project| proto::ResharedProject {
1394                    id: project.id.to_proto(),
1395                    collaborators: project
1396                        .collaborators
1397                        .iter()
1398                        .map(|collaborator| collaborator.to_proto())
1399                        .collect(),
1400                })
1401                .collect(),
1402            rejoined_projects: rejoined_room
1403                .rejoined_projects
1404                .iter()
1405                .map(|rejoined_project| rejoined_project.to_proto())
1406                .collect(),
1407        })?;
1408        room_updated(&rejoined_room.room, &session.peer);
1409
1410        for project in &rejoined_room.reshared_projects {
1411            for collaborator in &project.collaborators {
1412                session
1413                    .peer
1414                    .send(
1415                        collaborator.connection_id,
1416                        proto::UpdateProjectCollaborator {
1417                            project_id: project.id.to_proto(),
1418                            old_peer_id: Some(project.old_connection_id.into()),
1419                            new_peer_id: Some(session.connection_id.into()),
1420                        },
1421                    )
1422                    .trace_err();
1423            }
1424
1425            broadcast(
1426                Some(session.connection_id),
1427                project
1428                    .collaborators
1429                    .iter()
1430                    .map(|collaborator| collaborator.connection_id),
1431                |connection_id| {
1432                    session.peer.forward_send(
1433                        session.connection_id,
1434                        connection_id,
1435                        proto::UpdateProject {
1436                            project_id: project.id.to_proto(),
1437                            worktrees: project.worktrees.clone(),
1438                        },
1439                    )
1440                },
1441            );
1442        }
1443
1444        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1445
1446        let rejoined_room = rejoined_room.into_inner();
1447
1448        room = rejoined_room.room;
1449        channel = rejoined_room.channel;
1450    }
1451
1452    if let Some(channel) = channel {
1453        channel_updated(
1454            &channel,
1455            &room,
1456            &session.peer,
1457            &*session.connection_pool().await,
1458        );
1459    }
1460
1461    update_user_contacts(session.user_id(), &session).await?;
1462    Ok(())
1463}
1464
1465fn notify_rejoined_projects(
1466    rejoined_projects: &mut Vec<RejoinedProject>,
1467    session: &Session,
1468) -> Result<()> {
1469    for project in rejoined_projects.iter() {
1470        for collaborator in &project.collaborators {
1471            session
1472                .peer
1473                .send(
1474                    collaborator.connection_id,
1475                    proto::UpdateProjectCollaborator {
1476                        project_id: project.id.to_proto(),
1477                        old_peer_id: Some(project.old_connection_id.into()),
1478                        new_peer_id: Some(session.connection_id.into()),
1479                    },
1480                )
1481                .trace_err();
1482        }
1483    }
1484
1485    for project in rejoined_projects {
1486        for worktree in mem::take(&mut project.worktrees) {
1487            // Stream this worktree's entries.
1488            let message = proto::UpdateWorktree {
1489                project_id: project.id.to_proto(),
1490                worktree_id: worktree.id,
1491                abs_path: worktree.abs_path.clone(),
1492                root_name: worktree.root_name,
1493                updated_entries: worktree.updated_entries,
1494                removed_entries: worktree.removed_entries,
1495                scan_id: worktree.scan_id,
1496                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1497                updated_repositories: worktree.updated_repositories,
1498                removed_repositories: worktree.removed_repositories,
1499            };
1500            for update in proto::split_worktree_update(message) {
1501                session.peer.send(session.connection_id, update)?;
1502            }
1503
1504            // Stream this worktree's diagnostics.
1505            for summary in worktree.diagnostic_summaries {
1506                session.peer.send(
1507                    session.connection_id,
1508                    proto::UpdateDiagnosticSummary {
1509                        project_id: project.id.to_proto(),
1510                        worktree_id: worktree.id,
1511                        summary: Some(summary),
1512                    },
1513                )?;
1514            }
1515
1516            for settings_file in worktree.settings_files {
1517                session.peer.send(
1518                    session.connection_id,
1519                    proto::UpdateWorktreeSettings {
1520                        project_id: project.id.to_proto(),
1521                        worktree_id: worktree.id,
1522                        path: settings_file.path,
1523                        content: Some(settings_file.content),
1524                        kind: Some(settings_file.kind.to_proto().into()),
1525                    },
1526                )?;
1527            }
1528        }
1529
1530        for repository in mem::take(&mut project.updated_repositories) {
1531            for update in split_repository_update(repository) {
1532                session.peer.send(session.connection_id, update)?;
1533            }
1534        }
1535
1536        for id in mem::take(&mut project.removed_repositories) {
1537            session.peer.send(
1538                session.connection_id,
1539                proto::RemoveRepository {
1540                    project_id: project.id.to_proto(),
1541                    id,
1542                },
1543            )?;
1544        }
1545    }
1546
1547    Ok(())
1548}
1549
1550/// leave room disconnects from the room.
1551async fn leave_room(
1552    _: proto::LeaveRoom,
1553    response: Response<proto::LeaveRoom>,
1554    session: Session,
1555) -> Result<()> {
1556    leave_room_for_session(&session, session.connection_id).await?;
1557    response.send(proto::Ack {})?;
1558    Ok(())
1559}
1560
1561/// Updates the permissions of someone else in the room.
1562async fn set_room_participant_role(
1563    request: proto::SetRoomParticipantRole,
1564    response: Response<proto::SetRoomParticipantRole>,
1565    session: Session,
1566) -> Result<()> {
1567    let user_id = UserId::from_proto(request.user_id);
1568    let role = ChannelRole::from(request.role());
1569
1570    let (livekit_room, can_publish) = {
1571        let room = session
1572            .db()
1573            .await
1574            .set_room_participant_role(
1575                session.user_id(),
1576                RoomId::from_proto(request.room_id),
1577                user_id,
1578                role,
1579            )
1580            .await?;
1581
1582        let livekit_room = room.livekit_room.clone();
1583        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1584        room_updated(&room, &session.peer);
1585        (livekit_room, can_publish)
1586    };
1587
1588    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1589        live_kit
1590            .update_participant(
1591                livekit_room.clone(),
1592                request.user_id.to_string(),
1593                livekit_api::proto::ParticipantPermission {
1594                    can_subscribe: true,
1595                    can_publish,
1596                    can_publish_data: can_publish,
1597                    hidden: false,
1598                    recorder: false,
1599                },
1600            )
1601            .await
1602            .trace_err();
1603    }
1604
1605    response.send(proto::Ack {})?;
1606    Ok(())
1607}
1608
1609/// Call someone else into the current room
1610async fn call(
1611    request: proto::Call,
1612    response: Response<proto::Call>,
1613    session: Session,
1614) -> Result<()> {
1615    let room_id = RoomId::from_proto(request.room_id);
1616    let calling_user_id = session.user_id();
1617    let calling_connection_id = session.connection_id;
1618    let called_user_id = UserId::from_proto(request.called_user_id);
1619    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1620    if !session
1621        .db()
1622        .await
1623        .has_contact(calling_user_id, called_user_id)
1624        .await?
1625    {
1626        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1627    }
1628
1629    let incoming_call = {
1630        let (room, incoming_call) = &mut *session
1631            .db()
1632            .await
1633            .call(
1634                room_id,
1635                calling_user_id,
1636                calling_connection_id,
1637                called_user_id,
1638                initial_project_id,
1639            )
1640            .await?;
1641        room_updated(room, &session.peer);
1642        mem::take(incoming_call)
1643    };
1644    update_user_contacts(called_user_id, &session).await?;
1645
1646    let mut calls = session
1647        .connection_pool()
1648        .await
1649        .user_connection_ids(called_user_id)
1650        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1651        .collect::<FuturesUnordered<_>>();
1652
1653    while let Some(call_response) = calls.next().await {
1654        match call_response.as_ref() {
1655            Ok(_) => {
1656                response.send(proto::Ack {})?;
1657                return Ok(());
1658            }
1659            Err(_) => {
1660                call_response.trace_err();
1661            }
1662        }
1663    }
1664
1665    {
1666        let room = session
1667            .db()
1668            .await
1669            .call_failed(room_id, called_user_id)
1670            .await?;
1671        room_updated(&room, &session.peer);
1672    }
1673    update_user_contacts(called_user_id, &session).await?;
1674
1675    Err(anyhow!("failed to ring user"))?
1676}
1677
1678/// Cancel an outgoing call.
1679async fn cancel_call(
1680    request: proto::CancelCall,
1681    response: Response<proto::CancelCall>,
1682    session: Session,
1683) -> Result<()> {
1684    let called_user_id = UserId::from_proto(request.called_user_id);
1685    let room_id = RoomId::from_proto(request.room_id);
1686    {
1687        let room = session
1688            .db()
1689            .await
1690            .cancel_call(room_id, session.connection_id, called_user_id)
1691            .await?;
1692        room_updated(&room, &session.peer);
1693    }
1694
1695    for connection_id in session
1696        .connection_pool()
1697        .await
1698        .user_connection_ids(called_user_id)
1699    {
1700        session
1701            .peer
1702            .send(
1703                connection_id,
1704                proto::CallCanceled {
1705                    room_id: room_id.to_proto(),
1706                },
1707            )
1708            .trace_err();
1709    }
1710    response.send(proto::Ack {})?;
1711
1712    update_user_contacts(called_user_id, &session).await?;
1713    Ok(())
1714}
1715
1716/// Decline an incoming call.
1717async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1718    let room_id = RoomId::from_proto(message.room_id);
1719    {
1720        let room = session
1721            .db()
1722            .await
1723            .decline_call(Some(room_id), session.user_id())
1724            .await?
1725            .context("declining call")?;
1726        room_updated(&room, &session.peer);
1727    }
1728
1729    for connection_id in session
1730        .connection_pool()
1731        .await
1732        .user_connection_ids(session.user_id())
1733    {
1734        session
1735            .peer
1736            .send(
1737                connection_id,
1738                proto::CallCanceled {
1739                    room_id: room_id.to_proto(),
1740                },
1741            )
1742            .trace_err();
1743    }
1744    update_user_contacts(session.user_id(), &session).await?;
1745    Ok(())
1746}
1747
1748/// Updates other participants in the room with your current location.
1749async fn update_participant_location(
1750    request: proto::UpdateParticipantLocation,
1751    response: Response<proto::UpdateParticipantLocation>,
1752    session: Session,
1753) -> Result<()> {
1754    let room_id = RoomId::from_proto(request.room_id);
1755    let location = request.location.context("invalid location")?;
1756
1757    let db = session.db().await;
1758    let room = db
1759        .update_room_participant_location(room_id, session.connection_id, location)
1760        .await?;
1761
1762    room_updated(&room, &session.peer);
1763    response.send(proto::Ack {})?;
1764    Ok(())
1765}
1766
1767/// Share a project into the room.
1768async fn share_project(
1769    request: proto::ShareProject,
1770    response: Response<proto::ShareProject>,
1771    session: Session,
1772) -> Result<()> {
1773    let (project_id, room) = &*session
1774        .db()
1775        .await
1776        .share_project(
1777            RoomId::from_proto(request.room_id),
1778            session.connection_id,
1779            &request.worktrees,
1780            request.is_ssh_project,
1781        )
1782        .await?;
1783    response.send(proto::ShareProjectResponse {
1784        project_id: project_id.to_proto(),
1785    })?;
1786    room_updated(room, &session.peer);
1787
1788    Ok(())
1789}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1793    let project_id = ProjectId::from_proto(message.project_id);
1794    unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
1797async fn unshare_project_internal(
1798    project_id: ProjectId,
1799    connection_id: ConnectionId,
1800    session: &Session,
1801) -> Result<()> {
1802    let delete = {
1803        let room_guard = session
1804            .db()
1805            .await
1806            .unshare_project(project_id, connection_id)
1807            .await?;
1808
1809        let (delete, room, guest_connection_ids) = &*room_guard;
1810
1811        let message = proto::UnshareProject {
1812            project_id: project_id.to_proto(),
1813        };
1814
1815        broadcast(
1816            Some(connection_id),
1817            guest_connection_ids.iter().copied(),
1818            |conn_id| session.peer.send(conn_id, message.clone()),
1819        );
1820        if let Some(room) = room {
1821            room_updated(room, &session.peer);
1822        }
1823
1824        *delete
1825    };
1826
1827    if delete {
1828        let db = session.db().await;
1829        db.delete_project(project_id).await?;
1830    }
1831
1832    Ok(())
1833}
1834
1835/// Join someone elses shared project.
1836async fn join_project(
1837    request: proto::JoinProject,
1838    response: Response<proto::JoinProject>,
1839    session: Session,
1840) -> Result<()> {
1841    let project_id = ProjectId::from_proto(request.project_id);
1842
1843    tracing::info!(%project_id, "join project");
1844
1845    let db = session.db().await;
1846    let (project, replica_id) = &mut *db
1847        .join_project(project_id, session.connection_id, session.user_id())
1848        .await?;
1849    drop(db);
1850    tracing::info!(%project_id, "join remote project");
1851    join_project_internal(response, session, project, replica_id)
1852}
1853
1854trait JoinProjectInternalResponse {
1855    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1856}
1857impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1858    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1859        Response::<proto::JoinProject>::send(self, result)
1860    }
1861}
1862
1863fn join_project_internal(
1864    response: impl JoinProjectInternalResponse,
1865    session: Session,
1866    project: &mut Project,
1867    replica_id: &ReplicaId,
1868) -> Result<()> {
1869    let collaborators = project
1870        .collaborators
1871        .iter()
1872        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1873        .map(|collaborator| collaborator.to_proto())
1874        .collect::<Vec<_>>();
1875    let project_id = project.id;
1876    let guest_user_id = session.user_id();
1877
1878    let worktrees = project
1879        .worktrees
1880        .iter()
1881        .map(|(id, worktree)| proto::WorktreeMetadata {
1882            id: *id,
1883            root_name: worktree.root_name.clone(),
1884            visible: worktree.visible,
1885            abs_path: worktree.abs_path.clone(),
1886        })
1887        .collect::<Vec<_>>();
1888
1889    let add_project_collaborator = proto::AddProjectCollaborator {
1890        project_id: project_id.to_proto(),
1891        collaborator: Some(proto::Collaborator {
1892            peer_id: Some(session.connection_id.into()),
1893            replica_id: replica_id.0 as u32,
1894            user_id: guest_user_id.to_proto(),
1895            is_host: false,
1896        }),
1897    };
1898
1899    for collaborator in &collaborators {
1900        session
1901            .peer
1902            .send(
1903                collaborator.peer_id.unwrap().into(),
1904                add_project_collaborator.clone(),
1905            )
1906            .trace_err();
1907    }
1908
1909    // First, we send the metadata associated with each worktree.
1910    response.send(proto::JoinProjectResponse {
1911        project_id: project.id.0 as u64,
1912        worktrees: worktrees.clone(),
1913        replica_id: replica_id.0 as u32,
1914        collaborators: collaborators.clone(),
1915        language_servers: project.language_servers.clone(),
1916        role: project.role.into(),
1917    })?;
1918
1919    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1920        // Stream this worktree's entries.
1921        let message = proto::UpdateWorktree {
1922            project_id: project_id.to_proto(),
1923            worktree_id,
1924            abs_path: worktree.abs_path.clone(),
1925            root_name: worktree.root_name,
1926            updated_entries: worktree.entries,
1927            removed_entries: Default::default(),
1928            scan_id: worktree.scan_id,
1929            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1930            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1931            removed_repositories: Default::default(),
1932        };
1933        for update in proto::split_worktree_update(message) {
1934            session.peer.send(session.connection_id, update.clone())?;
1935        }
1936
1937        // Stream this worktree's diagnostics.
1938        for summary in worktree.diagnostic_summaries {
1939            session.peer.send(
1940                session.connection_id,
1941                proto::UpdateDiagnosticSummary {
1942                    project_id: project_id.to_proto(),
1943                    worktree_id: worktree.id,
1944                    summary: Some(summary),
1945                },
1946            )?;
1947        }
1948
1949        for settings_file in worktree.settings_files {
1950            session.peer.send(
1951                session.connection_id,
1952                proto::UpdateWorktreeSettings {
1953                    project_id: project_id.to_proto(),
1954                    worktree_id: worktree.id,
1955                    path: settings_file.path,
1956                    content: Some(settings_file.content),
1957                    kind: Some(settings_file.kind.to_proto() as i32),
1958                },
1959            )?;
1960        }
1961    }
1962
1963    for repository in mem::take(&mut project.repositories) {
1964        for update in split_repository_update(repository) {
1965            session.peer.send(session.connection_id, update)?;
1966        }
1967    }
1968
1969    for language_server in &project.language_servers {
1970        session.peer.send(
1971            session.connection_id,
1972            proto::UpdateLanguageServer {
1973                project_id: project_id.to_proto(),
1974                language_server_id: language_server.id,
1975                variant: Some(
1976                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1977                        proto::LspDiskBasedDiagnosticsUpdated {},
1978                    ),
1979                ),
1980            },
1981        )?;
1982    }
1983
1984    Ok(())
1985}
1986
1987/// Leave someone elses shared project.
1988async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1989    let sender_id = session.connection_id;
1990    let project_id = ProjectId::from_proto(request.project_id);
1991    let db = session.db().await;
1992
1993    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1994    tracing::info!(
1995        %project_id,
1996        "leave project"
1997    );
1998
1999    project_left(project, &session);
2000    if let Some(room) = room {
2001        room_updated(room, &session.peer);
2002    }
2003
2004    Ok(())
2005}
2006
2007/// Updates other participants with changes to the project
2008async fn update_project(
2009    request: proto::UpdateProject,
2010    response: Response<proto::UpdateProject>,
2011    session: Session,
2012) -> Result<()> {
2013    let project_id = ProjectId::from_proto(request.project_id);
2014    let (room, guest_connection_ids) = &*session
2015        .db()
2016        .await
2017        .update_project(project_id, session.connection_id, &request.worktrees)
2018        .await?;
2019    broadcast(
2020        Some(session.connection_id),
2021        guest_connection_ids.iter().copied(),
2022        |connection_id| {
2023            session
2024                .peer
2025                .forward_send(session.connection_id, connection_id, request.clone())
2026        },
2027    );
2028    if let Some(room) = room {
2029        room_updated(room, &session.peer);
2030    }
2031    response.send(proto::Ack {})?;
2032
2033    Ok(())
2034}
2035
2036/// Updates other participants with changes to the worktree
2037async fn update_worktree(
2038    request: proto::UpdateWorktree,
2039    response: Response<proto::UpdateWorktree>,
2040    session: Session,
2041) -> Result<()> {
2042    let guest_connection_ids = session
2043        .db()
2044        .await
2045        .update_worktree(&request, session.connection_id)
2046        .await?;
2047
2048    broadcast(
2049        Some(session.connection_id),
2050        guest_connection_ids.iter().copied(),
2051        |connection_id| {
2052            session
2053                .peer
2054                .forward_send(session.connection_id, connection_id, request.clone())
2055        },
2056    );
2057    response.send(proto::Ack {})?;
2058    Ok(())
2059}
2060
2061async fn update_repository(
2062    request: proto::UpdateRepository,
2063    response: Response<proto::UpdateRepository>,
2064    session: Session,
2065) -> Result<()> {
2066    let guest_connection_ids = session
2067        .db()
2068        .await
2069        .update_repository(&request, session.connection_id)
2070        .await?;
2071
2072    broadcast(
2073        Some(session.connection_id),
2074        guest_connection_ids.iter().copied(),
2075        |connection_id| {
2076            session
2077                .peer
2078                .forward_send(session.connection_id, connection_id, request.clone())
2079        },
2080    );
2081    response.send(proto::Ack {})?;
2082    Ok(())
2083}
2084
2085async fn remove_repository(
2086    request: proto::RemoveRepository,
2087    response: Response<proto::RemoveRepository>,
2088    session: Session,
2089) -> Result<()> {
2090    let guest_connection_ids = session
2091        .db()
2092        .await
2093        .remove_repository(&request, session.connection_id)
2094        .await?;
2095
2096    broadcast(
2097        Some(session.connection_id),
2098        guest_connection_ids.iter().copied(),
2099        |connection_id| {
2100            session
2101                .peer
2102                .forward_send(session.connection_id, connection_id, request.clone())
2103        },
2104    );
2105    response.send(proto::Ack {})?;
2106    Ok(())
2107}
2108
2109/// Updates other participants with changes to the diagnostics
2110async fn update_diagnostic_summary(
2111    message: proto::UpdateDiagnosticSummary,
2112    session: Session,
2113) -> Result<()> {
2114    let guest_connection_ids = session
2115        .db()
2116        .await
2117        .update_diagnostic_summary(&message, session.connection_id)
2118        .await?;
2119
2120    broadcast(
2121        Some(session.connection_id),
2122        guest_connection_ids.iter().copied(),
2123        |connection_id| {
2124            session
2125                .peer
2126                .forward_send(session.connection_id, connection_id, message.clone())
2127        },
2128    );
2129
2130    Ok(())
2131}
2132
2133/// Updates other participants with changes to the worktree settings
2134async fn update_worktree_settings(
2135    message: proto::UpdateWorktreeSettings,
2136    session: Session,
2137) -> Result<()> {
2138    let guest_connection_ids = session
2139        .db()
2140        .await
2141        .update_worktree_settings(&message, session.connection_id)
2142        .await?;
2143
2144    broadcast(
2145        Some(session.connection_id),
2146        guest_connection_ids.iter().copied(),
2147        |connection_id| {
2148            session
2149                .peer
2150                .forward_send(session.connection_id, connection_id, message.clone())
2151        },
2152    );
2153
2154    Ok(())
2155}
2156
2157/// Notify other participants that a language server has started.
2158async fn start_language_server(
2159    request: proto::StartLanguageServer,
2160    session: Session,
2161) -> Result<()> {
2162    let guest_connection_ids = session
2163        .db()
2164        .await
2165        .start_language_server(&request, session.connection_id)
2166        .await?;
2167
2168    broadcast(
2169        Some(session.connection_id),
2170        guest_connection_ids.iter().copied(),
2171        |connection_id| {
2172            session
2173                .peer
2174                .forward_send(session.connection_id, connection_id, request.clone())
2175        },
2176    );
2177    Ok(())
2178}
2179
2180/// Notify other participants that a language server has changed.
2181async fn update_language_server(
2182    request: proto::UpdateLanguageServer,
2183    session: Session,
2184) -> Result<()> {
2185    let project_id = ProjectId::from_proto(request.project_id);
2186    let project_connection_ids = session
2187        .db()
2188        .await
2189        .project_connection_ids(project_id, session.connection_id, true)
2190        .await?;
2191    broadcast(
2192        Some(session.connection_id),
2193        project_connection_ids.iter().copied(),
2194        |connection_id| {
2195            session
2196                .peer
2197                .forward_send(session.connection_id, connection_id, request.clone())
2198        },
2199    );
2200    Ok(())
2201}
2202
2203/// forward a project request to the host. These requests should be read only
2204/// as guests are allowed to send them.
2205async fn forward_read_only_project_request<T>(
2206    request: T,
2207    response: Response<T>,
2208    session: Session,
2209) -> Result<()>
2210where
2211    T: EntityMessage + RequestMessage,
2212{
2213    let project_id = ProjectId::from_proto(request.remote_entity_id());
2214    let host_connection_id = session
2215        .db()
2216        .await
2217        .host_for_read_only_project_request(project_id, session.connection_id)
2218        .await?;
2219    let payload = session
2220        .peer
2221        .forward_request(session.connection_id, host_connection_id, request)
2222        .await?;
2223    response.send(payload)?;
2224    Ok(())
2225}
2226
2227async fn forward_find_search_candidates_request(
2228    request: proto::FindSearchCandidates,
2229    response: Response<proto::FindSearchCandidates>,
2230    session: Session,
2231) -> Result<()> {
2232    let project_id = ProjectId::from_proto(request.remote_entity_id());
2233    let host_connection_id = session
2234        .db()
2235        .await
2236        .host_for_read_only_project_request(project_id, session.connection_id)
2237        .await?;
2238    let payload = session
2239        .peer
2240        .forward_request(session.connection_id, host_connection_id, request)
2241        .await?;
2242    response.send(payload)?;
2243    Ok(())
2244}
2245
2246/// forward a project request to the host. These requests are disallowed
2247/// for guests.
2248async fn forward_mutating_project_request<T>(
2249    request: T,
2250    response: Response<T>,
2251    session: Session,
2252) -> Result<()>
2253where
2254    T: EntityMessage + RequestMessage,
2255{
2256    let project_id = ProjectId::from_proto(request.remote_entity_id());
2257
2258    let host_connection_id = session
2259        .db()
2260        .await
2261        .host_for_mutating_project_request(project_id, session.connection_id)
2262        .await?;
2263    let payload = session
2264        .peer
2265        .forward_request(session.connection_id, host_connection_id, request)
2266        .await?;
2267    response.send(payload)?;
2268    Ok(())
2269}
2270
2271/// Notify other participants that a new buffer has been created
2272async fn create_buffer_for_peer(
2273    request: proto::CreateBufferForPeer,
2274    session: Session,
2275) -> Result<()> {
2276    session
2277        .db()
2278        .await
2279        .check_user_is_project_host(
2280            ProjectId::from_proto(request.project_id),
2281            session.connection_id,
2282        )
2283        .await?;
2284    let peer_id = request.peer_id.context("invalid peer id")?;
2285    session
2286        .peer
2287        .forward_send(session.connection_id, peer_id.into(), request)?;
2288    Ok(())
2289}
2290
2291/// Notify other participants that a buffer has been updated. This is
2292/// allowed for guests as long as the update is limited to selections.
2293async fn update_buffer(
2294    request: proto::UpdateBuffer,
2295    response: Response<proto::UpdateBuffer>,
2296    session: Session,
2297) -> Result<()> {
2298    let project_id = ProjectId::from_proto(request.project_id);
2299    let mut capability = Capability::ReadOnly;
2300
2301    for op in request.operations.iter() {
2302        match op.variant {
2303            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2304            Some(_) => capability = Capability::ReadWrite,
2305        }
2306    }
2307
2308    let host = {
2309        let guard = session
2310            .db()
2311            .await
2312            .connections_for_buffer_update(project_id, session.connection_id, capability)
2313            .await?;
2314
2315        let (host, guests) = &*guard;
2316
2317        broadcast(
2318            Some(session.connection_id),
2319            guests.clone(),
2320            |connection_id| {
2321                session
2322                    .peer
2323                    .forward_send(session.connection_id, connection_id, request.clone())
2324            },
2325        );
2326
2327        *host
2328    };
2329
2330    if host != session.connection_id {
2331        session
2332            .peer
2333            .forward_request(session.connection_id, host, request.clone())
2334            .await?;
2335    }
2336
2337    response.send(proto::Ack {})?;
2338    Ok(())
2339}
2340
2341async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2342    let project_id = ProjectId::from_proto(message.project_id);
2343
2344    let operation = message.operation.as_ref().context("invalid operation")?;
2345    let capability = match operation.variant.as_ref() {
2346        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2347            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2348                match buffer_op.variant {
2349                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2350                        Capability::ReadOnly
2351                    }
2352                    _ => Capability::ReadWrite,
2353                }
2354            } else {
2355                Capability::ReadWrite
2356            }
2357        }
2358        Some(_) => Capability::ReadWrite,
2359        None => Capability::ReadOnly,
2360    };
2361
2362    let guard = session
2363        .db()
2364        .await
2365        .connections_for_buffer_update(project_id, session.connection_id, capability)
2366        .await?;
2367
2368    let (host, guests) = &*guard;
2369
2370    broadcast(
2371        Some(session.connection_id),
2372        guests.iter().chain([host]).copied(),
2373        |connection_id| {
2374            session
2375                .peer
2376                .forward_send(session.connection_id, connection_id, message.clone())
2377        },
2378    );
2379
2380    Ok(())
2381}
2382
2383/// Notify other participants that a project has been updated.
2384async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2385    request: T,
2386    session: Session,
2387) -> Result<()> {
2388    let project_id = ProjectId::from_proto(request.remote_entity_id());
2389    let project_connection_ids = session
2390        .db()
2391        .await
2392        .project_connection_ids(project_id, session.connection_id, false)
2393        .await?;
2394
2395    broadcast(
2396        Some(session.connection_id),
2397        project_connection_ids.iter().copied(),
2398        |connection_id| {
2399            session
2400                .peer
2401                .forward_send(session.connection_id, connection_id, request.clone())
2402        },
2403    );
2404    Ok(())
2405}
2406
2407/// Start following another user in a call.
2408async fn follow(
2409    request: proto::Follow,
2410    response: Response<proto::Follow>,
2411    session: Session,
2412) -> Result<()> {
2413    let room_id = RoomId::from_proto(request.room_id);
2414    let project_id = request.project_id.map(ProjectId::from_proto);
2415    let leader_id = request.leader_id.context("invalid leader id")?.into();
2416    let follower_id = session.connection_id;
2417
2418    session
2419        .db()
2420        .await
2421        .check_room_participants(room_id, leader_id, session.connection_id)
2422        .await?;
2423
2424    let response_payload = session
2425        .peer
2426        .forward_request(session.connection_id, leader_id, request)
2427        .await?;
2428    response.send(response_payload)?;
2429
2430    if let Some(project_id) = project_id {
2431        let room = session
2432            .db()
2433            .await
2434            .follow(room_id, project_id, leader_id, follower_id)
2435            .await?;
2436        room_updated(&room, &session.peer);
2437    }
2438
2439    Ok(())
2440}
2441
2442/// Stop following another user in a call.
2443async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2444    let room_id = RoomId::from_proto(request.room_id);
2445    let project_id = request.project_id.map(ProjectId::from_proto);
2446    let leader_id = request.leader_id.context("invalid leader id")?.into();
2447    let follower_id = session.connection_id;
2448
2449    session
2450        .db()
2451        .await
2452        .check_room_participants(room_id, leader_id, session.connection_id)
2453        .await?;
2454
2455    session
2456        .peer
2457        .forward_send(session.connection_id, leader_id, request)?;
2458
2459    if let Some(project_id) = project_id {
2460        let room = session
2461            .db()
2462            .await
2463            .unfollow(room_id, project_id, leader_id, follower_id)
2464            .await?;
2465        room_updated(&room, &session.peer);
2466    }
2467
2468    Ok(())
2469}
2470
2471/// Notify everyone following you of your current location.
2472async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2473    let room_id = RoomId::from_proto(request.room_id);
2474    let database = session.db.lock().await;
2475
2476    let connection_ids = if let Some(project_id) = request.project_id {
2477        let project_id = ProjectId::from_proto(project_id);
2478        database
2479            .project_connection_ids(project_id, session.connection_id, true)
2480            .await?
2481    } else {
2482        database
2483            .room_connection_ids(room_id, session.connection_id)
2484            .await?
2485    };
2486
2487    // For now, don't send view update messages back to that view's current leader.
2488    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2489        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2490        _ => None,
2491    });
2492
2493    for connection_id in connection_ids.iter().cloned() {
2494        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2495            session
2496                .peer
2497                .forward_send(session.connection_id, connection_id, request.clone())?;
2498        }
2499    }
2500    Ok(())
2501}
2502
2503/// Get public data about users.
2504async fn get_users(
2505    request: proto::GetUsers,
2506    response: Response<proto::GetUsers>,
2507    session: Session,
2508) -> Result<()> {
2509    let user_ids = request
2510        .user_ids
2511        .into_iter()
2512        .map(UserId::from_proto)
2513        .collect();
2514    let users = session
2515        .db()
2516        .await
2517        .get_users_by_ids(user_ids)
2518        .await?
2519        .into_iter()
2520        .map(|user| proto::User {
2521            id: user.id.to_proto(),
2522            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2523            github_login: user.github_login,
2524            email: user.email_address,
2525            name: user.name,
2526        })
2527        .collect();
2528    response.send(proto::UsersResponse { users })?;
2529    Ok(())
2530}
2531
2532/// Search for users (to invite) buy Github login
2533async fn fuzzy_search_users(
2534    request: proto::FuzzySearchUsers,
2535    response: Response<proto::FuzzySearchUsers>,
2536    session: Session,
2537) -> Result<()> {
2538    let query = request.query;
2539    let users = match query.len() {
2540        0 => vec![],
2541        1 | 2 => session
2542            .db()
2543            .await
2544            .get_user_by_github_login(&query)
2545            .await?
2546            .into_iter()
2547            .collect(),
2548        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2549    };
2550    let users = users
2551        .into_iter()
2552        .filter(|user| user.id != session.user_id())
2553        .map(|user| proto::User {
2554            id: user.id.to_proto(),
2555            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2556            github_login: user.github_login,
2557            name: user.name,
2558            email: user.email_address,
2559        })
2560        .collect();
2561    response.send(proto::UsersResponse { users })?;
2562    Ok(())
2563}
2564
2565/// Send a contact request to another user.
2566async fn request_contact(
2567    request: proto::RequestContact,
2568    response: Response<proto::RequestContact>,
2569    session: Session,
2570) -> Result<()> {
2571    let requester_id = session.user_id();
2572    let responder_id = UserId::from_proto(request.responder_id);
2573    if requester_id == responder_id {
2574        return Err(anyhow!("cannot add yourself as a contact"))?;
2575    }
2576
2577    let notifications = session
2578        .db()
2579        .await
2580        .send_contact_request(requester_id, responder_id)
2581        .await?;
2582
2583    // Update outgoing contact requests of requester
2584    let mut update = proto::UpdateContacts::default();
2585    update.outgoing_requests.push(responder_id.to_proto());
2586    for connection_id in session
2587        .connection_pool()
2588        .await
2589        .user_connection_ids(requester_id)
2590    {
2591        session.peer.send(connection_id, update.clone())?;
2592    }
2593
2594    // Update incoming contact requests of responder
2595    let mut update = proto::UpdateContacts::default();
2596    update
2597        .incoming_requests
2598        .push(proto::IncomingContactRequest {
2599            requester_id: requester_id.to_proto(),
2600        });
2601    let connection_pool = session.connection_pool().await;
2602    for connection_id in connection_pool.user_connection_ids(responder_id) {
2603        session.peer.send(connection_id, update.clone())?;
2604    }
2605
2606    send_notifications(&connection_pool, &session.peer, notifications);
2607
2608    response.send(proto::Ack {})?;
2609    Ok(())
2610}
2611
2612/// Accept or decline a contact request
2613async fn respond_to_contact_request(
2614    request: proto::RespondToContactRequest,
2615    response: Response<proto::RespondToContactRequest>,
2616    session: Session,
2617) -> Result<()> {
2618    let responder_id = session.user_id();
2619    let requester_id = UserId::from_proto(request.requester_id);
2620    let db = session.db().await;
2621    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2622        db.dismiss_contact_notification(responder_id, requester_id)
2623            .await?;
2624    } else {
2625        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2626
2627        let notifications = db
2628            .respond_to_contact_request(responder_id, requester_id, accept)
2629            .await?;
2630        let requester_busy = db.is_user_busy(requester_id).await?;
2631        let responder_busy = db.is_user_busy(responder_id).await?;
2632
2633        let pool = session.connection_pool().await;
2634        // Update responder with new contact
2635        let mut update = proto::UpdateContacts::default();
2636        if accept {
2637            update
2638                .contacts
2639                .push(contact_for_user(requester_id, requester_busy, &pool));
2640        }
2641        update
2642            .remove_incoming_requests
2643            .push(requester_id.to_proto());
2644        for connection_id in pool.user_connection_ids(responder_id) {
2645            session.peer.send(connection_id, update.clone())?;
2646        }
2647
2648        // Update requester with new contact
2649        let mut update = proto::UpdateContacts::default();
2650        if accept {
2651            update
2652                .contacts
2653                .push(contact_for_user(responder_id, responder_busy, &pool));
2654        }
2655        update
2656            .remove_outgoing_requests
2657            .push(responder_id.to_proto());
2658
2659        for connection_id in pool.user_connection_ids(requester_id) {
2660            session.peer.send(connection_id, update.clone())?;
2661        }
2662
2663        send_notifications(&pool, &session.peer, notifications);
2664    }
2665
2666    response.send(proto::Ack {})?;
2667    Ok(())
2668}
2669
2670/// Remove a contact.
2671async fn remove_contact(
2672    request: proto::RemoveContact,
2673    response: Response<proto::RemoveContact>,
2674    session: Session,
2675) -> Result<()> {
2676    let requester_id = session.user_id();
2677    let responder_id = UserId::from_proto(request.user_id);
2678    let db = session.db().await;
2679    let (contact_accepted, deleted_notification_id) =
2680        db.remove_contact(requester_id, responder_id).await?;
2681
2682    let pool = session.connection_pool().await;
2683    // Update outgoing contact requests of requester
2684    let mut update = proto::UpdateContacts::default();
2685    if contact_accepted {
2686        update.remove_contacts.push(responder_id.to_proto());
2687    } else {
2688        update
2689            .remove_outgoing_requests
2690            .push(responder_id.to_proto());
2691    }
2692    for connection_id in pool.user_connection_ids(requester_id) {
2693        session.peer.send(connection_id, update.clone())?;
2694    }
2695
2696    // Update incoming contact requests of responder
2697    let mut update = proto::UpdateContacts::default();
2698    if contact_accepted {
2699        update.remove_contacts.push(requester_id.to_proto());
2700    } else {
2701        update
2702            .remove_incoming_requests
2703            .push(requester_id.to_proto());
2704    }
2705    for connection_id in pool.user_connection_ids(responder_id) {
2706        session.peer.send(connection_id, update.clone())?;
2707        if let Some(notification_id) = deleted_notification_id {
2708            session.peer.send(
2709                connection_id,
2710                proto::DeleteNotification {
2711                    notification_id: notification_id.to_proto(),
2712                },
2713            )?;
2714        }
2715    }
2716
2717    response.send(proto::Ack {})?;
2718    Ok(())
2719}
2720
2721fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2722    version.0.minor() < 139
2723}
2724
2725async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2726    if is_staff {
2727        return Ok(proto::Plan::ZedPro);
2728    }
2729
2730    let subscription = db.get_active_billing_subscription(user_id).await?;
2731    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2732
2733    let plan = if let Some(subscription_kind) = subscription_kind {
2734        match subscription_kind {
2735            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2736            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2737            SubscriptionKind::ZedFree => proto::Plan::Free,
2738        }
2739    } else {
2740        proto::Plan::Free
2741    };
2742
2743    Ok(plan)
2744}
2745
2746async fn make_update_user_plan_message(
2747    user: &User,
2748    is_staff: bool,
2749    db: &Arc<Database>,
2750    llm_db: Option<Arc<LlmDatabase>>,
2751) -> Result<proto::UpdateUserPlan> {
2752    let feature_flags = db.get_user_flags(user.id).await?;
2753    let plan = current_plan(db, user.id, is_staff).await?;
2754    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2755    let billing_preferences = db.get_billing_preferences(user.id).await?;
2756
2757    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2758        let subscription = db.get_active_billing_subscription(user.id).await?;
2759
2760        let subscription_period =
2761            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2762
2763        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2764            llm_db
2765                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2766                .await?
2767        } else {
2768            None
2769        };
2770
2771        (subscription_period, usage)
2772    } else {
2773        (None, None)
2774    };
2775
2776    let account_too_young =
2777        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2778
2779    Ok(proto::UpdateUserPlan {
2780        plan: plan.into(),
2781        trial_started_at: billing_customer
2782            .as_ref()
2783            .and_then(|billing_customer| billing_customer.trial_started_at)
2784            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2785        is_usage_based_billing_enabled: if is_staff {
2786            Some(true)
2787        } else {
2788            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2789        },
2790        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2791            proto::SubscriptionPeriod {
2792                started_at: started_at.timestamp() as u64,
2793                ended_at: ended_at.timestamp() as u64,
2794            }
2795        }),
2796        account_too_young: Some(account_too_young),
2797        has_overdue_invoices: billing_customer
2798            .map(|billing_customer| billing_customer.has_overdue_invoices),
2799        usage: usage.map(|usage| {
2800            let plan = match plan {
2801                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2802                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2803                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2804            };
2805
2806            let model_requests_limit = match plan.model_requests_limit() {
2807                zed_llm_client::UsageLimit::Limited(limit) => {
2808                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2809                        && feature_flags
2810                            .iter()
2811                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2812                    {
2813                        1_000
2814                    } else {
2815                        limit
2816                    };
2817
2818                    zed_llm_client::UsageLimit::Limited(limit)
2819                }
2820                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2821            };
2822
2823            proto::SubscriptionUsage {
2824                model_requests_usage_amount: usage.model_requests as u32,
2825                model_requests_usage_limit: Some(proto::UsageLimit {
2826                    variant: Some(match model_requests_limit {
2827                        zed_llm_client::UsageLimit::Limited(limit) => {
2828                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2829                                limit: limit as u32,
2830                            })
2831                        }
2832                        zed_llm_client::UsageLimit::Unlimited => {
2833                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2834                        }
2835                    }),
2836                }),
2837                edit_predictions_usage_amount: usage.edit_predictions as u32,
2838                edit_predictions_usage_limit: Some(proto::UsageLimit {
2839                    variant: Some(match plan.edit_predictions_limit() {
2840                        zed_llm_client::UsageLimit::Limited(limit) => {
2841                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2842                                limit: limit as u32,
2843                            })
2844                        }
2845                        zed_llm_client::UsageLimit::Unlimited => {
2846                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2847                        }
2848                    }),
2849                }),
2850            }
2851        }),
2852    })
2853}
2854
2855async fn update_user_plan(session: &Session) -> Result<()> {
2856    let db = session.db().await;
2857
2858    let update_user_plan = make_update_user_plan_message(
2859        session.principal.user(),
2860        session.is_staff(),
2861        &db.0,
2862        session.app_state.llm_db.clone(),
2863    )
2864    .await?;
2865
2866    session
2867        .peer
2868        .send(session.connection_id, update_user_plan)
2869        .trace_err();
2870
2871    Ok(())
2872}
2873
2874async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2875    subscribe_user_to_channels(session.user_id(), &session).await?;
2876    Ok(())
2877}
2878
2879async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2880    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2881    let mut pool = session.connection_pool().await;
2882    for membership in &channels_for_user.channel_memberships {
2883        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2884    }
2885    session.peer.send(
2886        session.connection_id,
2887        build_update_user_channels(&channels_for_user),
2888    )?;
2889    session.peer.send(
2890        session.connection_id,
2891        build_channels_update(channels_for_user),
2892    )?;
2893    Ok(())
2894}
2895
2896/// Creates a new channel.
2897async fn create_channel(
2898    request: proto::CreateChannel,
2899    response: Response<proto::CreateChannel>,
2900    session: Session,
2901) -> Result<()> {
2902    let db = session.db().await;
2903
2904    let parent_id = request.parent_id.map(ChannelId::from_proto);
2905    let (channel, membership) = db
2906        .create_channel(&request.name, parent_id, session.user_id())
2907        .await?;
2908
2909    let root_id = channel.root_id();
2910    let channel = Channel::from_model(channel);
2911
2912    response.send(proto::CreateChannelResponse {
2913        channel: Some(channel.to_proto()),
2914        parent_id: request.parent_id,
2915    })?;
2916
2917    let mut connection_pool = session.connection_pool().await;
2918    if let Some(membership) = membership {
2919        connection_pool.subscribe_to_channel(
2920            membership.user_id,
2921            membership.channel_id,
2922            membership.role,
2923        );
2924        let update = proto::UpdateUserChannels {
2925            channel_memberships: vec![proto::ChannelMembership {
2926                channel_id: membership.channel_id.to_proto(),
2927                role: membership.role.into(),
2928            }],
2929            ..Default::default()
2930        };
2931        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2932            session.peer.send(connection_id, update.clone())?;
2933        }
2934    }
2935
2936    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2937        if !role.can_see_channel(channel.visibility) {
2938            continue;
2939        }
2940
2941        let update = proto::UpdateChannels {
2942            channels: vec![channel.to_proto()],
2943            ..Default::default()
2944        };
2945        session.peer.send(connection_id, update.clone())?;
2946    }
2947
2948    Ok(())
2949}
2950
2951/// Delete a channel
2952async fn delete_channel(
2953    request: proto::DeleteChannel,
2954    response: Response<proto::DeleteChannel>,
2955    session: Session,
2956) -> Result<()> {
2957    let db = session.db().await;
2958
2959    let channel_id = request.channel_id;
2960    let (root_channel, removed_channels) = db
2961        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2962        .await?;
2963    response.send(proto::Ack {})?;
2964
2965    // Notify members of removed channels
2966    let mut update = proto::UpdateChannels::default();
2967    update
2968        .delete_channels
2969        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2970
2971    let connection_pool = session.connection_pool().await;
2972    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2973        session.peer.send(connection_id, update.clone())?;
2974    }
2975
2976    Ok(())
2977}
2978
2979/// Invite someone to join a channel.
2980async fn invite_channel_member(
2981    request: proto::InviteChannelMember,
2982    response: Response<proto::InviteChannelMember>,
2983    session: Session,
2984) -> Result<()> {
2985    let db = session.db().await;
2986    let channel_id = ChannelId::from_proto(request.channel_id);
2987    let invitee_id = UserId::from_proto(request.user_id);
2988    let InviteMemberResult {
2989        channel,
2990        notifications,
2991    } = db
2992        .invite_channel_member(
2993            channel_id,
2994            invitee_id,
2995            session.user_id(),
2996            request.role().into(),
2997        )
2998        .await?;
2999
3000    let update = proto::UpdateChannels {
3001        channel_invitations: vec![channel.to_proto()],
3002        ..Default::default()
3003    };
3004
3005    let connection_pool = session.connection_pool().await;
3006    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3007        session.peer.send(connection_id, update.clone())?;
3008    }
3009
3010    send_notifications(&connection_pool, &session.peer, notifications);
3011
3012    response.send(proto::Ack {})?;
3013    Ok(())
3014}
3015
3016/// remove someone from a channel
3017async fn remove_channel_member(
3018    request: proto::RemoveChannelMember,
3019    response: Response<proto::RemoveChannelMember>,
3020    session: Session,
3021) -> Result<()> {
3022    let db = session.db().await;
3023    let channel_id = ChannelId::from_proto(request.channel_id);
3024    let member_id = UserId::from_proto(request.user_id);
3025
3026    let RemoveChannelMemberResult {
3027        membership_update,
3028        notification_id,
3029    } = db
3030        .remove_channel_member(channel_id, member_id, session.user_id())
3031        .await?;
3032
3033    let mut connection_pool = session.connection_pool().await;
3034    notify_membership_updated(
3035        &mut connection_pool,
3036        membership_update,
3037        member_id,
3038        &session.peer,
3039    );
3040    for connection_id in connection_pool.user_connection_ids(member_id) {
3041        if let Some(notification_id) = notification_id {
3042            session
3043                .peer
3044                .send(
3045                    connection_id,
3046                    proto::DeleteNotification {
3047                        notification_id: notification_id.to_proto(),
3048                    },
3049                )
3050                .trace_err();
3051        }
3052    }
3053
3054    response.send(proto::Ack {})?;
3055    Ok(())
3056}
3057
3058/// Toggle the channel between public and private.
3059/// Care is taken to maintain the invariant that public channels only descend from public channels,
3060/// (though members-only channels can appear at any point in the hierarchy).
3061async fn set_channel_visibility(
3062    request: proto::SetChannelVisibility,
3063    response: Response<proto::SetChannelVisibility>,
3064    session: Session,
3065) -> Result<()> {
3066    let db = session.db().await;
3067    let channel_id = ChannelId::from_proto(request.channel_id);
3068    let visibility = request.visibility().into();
3069
3070    let channel_model = db
3071        .set_channel_visibility(channel_id, visibility, session.user_id())
3072        .await?;
3073    let root_id = channel_model.root_id();
3074    let channel = Channel::from_model(channel_model);
3075
3076    let mut connection_pool = session.connection_pool().await;
3077    for (user_id, role) in connection_pool
3078        .channel_user_ids(root_id)
3079        .collect::<Vec<_>>()
3080        .into_iter()
3081    {
3082        let update = if role.can_see_channel(channel.visibility) {
3083            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3084            proto::UpdateChannels {
3085                channels: vec![channel.to_proto()],
3086                ..Default::default()
3087            }
3088        } else {
3089            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3090            proto::UpdateChannels {
3091                delete_channels: vec![channel.id.to_proto()],
3092                ..Default::default()
3093            }
3094        };
3095
3096        for connection_id in connection_pool.user_connection_ids(user_id) {
3097            session.peer.send(connection_id, update.clone())?;
3098        }
3099    }
3100
3101    response.send(proto::Ack {})?;
3102    Ok(())
3103}
3104
3105/// Alter the role for a user in the channel.
3106async fn set_channel_member_role(
3107    request: proto::SetChannelMemberRole,
3108    response: Response<proto::SetChannelMemberRole>,
3109    session: Session,
3110) -> Result<()> {
3111    let db = session.db().await;
3112    let channel_id = ChannelId::from_proto(request.channel_id);
3113    let member_id = UserId::from_proto(request.user_id);
3114    let result = db
3115        .set_channel_member_role(
3116            channel_id,
3117            session.user_id(),
3118            member_id,
3119            request.role().into(),
3120        )
3121        .await?;
3122
3123    match result {
3124        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3125            let mut connection_pool = session.connection_pool().await;
3126            notify_membership_updated(
3127                &mut connection_pool,
3128                membership_update,
3129                member_id,
3130                &session.peer,
3131            )
3132        }
3133        db::SetMemberRoleResult::InviteUpdated(channel) => {
3134            let update = proto::UpdateChannels {
3135                channel_invitations: vec![channel.to_proto()],
3136                ..Default::default()
3137            };
3138
3139            for connection_id in session
3140                .connection_pool()
3141                .await
3142                .user_connection_ids(member_id)
3143            {
3144                session.peer.send(connection_id, update.clone())?;
3145            }
3146        }
3147    }
3148
3149    response.send(proto::Ack {})?;
3150    Ok(())
3151}
3152
3153/// Change the name of a channel
3154async fn rename_channel(
3155    request: proto::RenameChannel,
3156    response: Response<proto::RenameChannel>,
3157    session: Session,
3158) -> Result<()> {
3159    let db = session.db().await;
3160    let channel_id = ChannelId::from_proto(request.channel_id);
3161    let channel_model = db
3162        .rename_channel(channel_id, session.user_id(), &request.name)
3163        .await?;
3164    let root_id = channel_model.root_id();
3165    let channel = Channel::from_model(channel_model);
3166
3167    response.send(proto::RenameChannelResponse {
3168        channel: Some(channel.to_proto()),
3169    })?;
3170
3171    let connection_pool = session.connection_pool().await;
3172    let update = proto::UpdateChannels {
3173        channels: vec![channel.to_proto()],
3174        ..Default::default()
3175    };
3176    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3177        if role.can_see_channel(channel.visibility) {
3178            session.peer.send(connection_id, update.clone())?;
3179        }
3180    }
3181
3182    Ok(())
3183}
3184
3185/// Move a channel to a new parent.
3186async fn move_channel(
3187    request: proto::MoveChannel,
3188    response: Response<proto::MoveChannel>,
3189    session: Session,
3190) -> Result<()> {
3191    let channel_id = ChannelId::from_proto(request.channel_id);
3192    let to = ChannelId::from_proto(request.to);
3193
3194    let (root_id, channels) = session
3195        .db()
3196        .await
3197        .move_channel(channel_id, to, session.user_id())
3198        .await?;
3199
3200    let connection_pool = session.connection_pool().await;
3201    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3202        let channels = channels
3203            .iter()
3204            .filter_map(|channel| {
3205                if role.can_see_channel(channel.visibility) {
3206                    Some(channel.to_proto())
3207                } else {
3208                    None
3209                }
3210            })
3211            .collect::<Vec<_>>();
3212        if channels.is_empty() {
3213            continue;
3214        }
3215
3216        let update = proto::UpdateChannels {
3217            channels,
3218            ..Default::default()
3219        };
3220
3221        session.peer.send(connection_id, update.clone())?;
3222    }
3223
3224    response.send(Ack {})?;
3225    Ok(())
3226}
3227
3228async fn reorder_channel(
3229    request: proto::ReorderChannel,
3230    response: Response<proto::ReorderChannel>,
3231    session: Session,
3232) -> Result<()> {
3233    let channel_id = ChannelId::from_proto(request.channel_id);
3234    let direction = request.direction();
3235
3236    let updated_channels = session
3237        .db()
3238        .await
3239        .reorder_channel(channel_id, direction, session.user_id())
3240        .await?;
3241
3242    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3243        let connection_pool = session.connection_pool().await;
3244        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3245            let channels = updated_channels
3246                .iter()
3247                .filter_map(|channel| {
3248                    if role.can_see_channel(channel.visibility) {
3249                        Some(channel.to_proto())
3250                    } else {
3251                        None
3252                    }
3253                })
3254                .collect::<Vec<_>>();
3255
3256            if channels.is_empty() {
3257                continue;
3258            }
3259
3260            let update = proto::UpdateChannels {
3261                channels,
3262                ..Default::default()
3263            };
3264
3265            session.peer.send(connection_id, update.clone())?;
3266        }
3267    }
3268
3269    response.send(Ack {})?;
3270    Ok(())
3271}
3272
3273/// Get the list of channel members
3274async fn get_channel_members(
3275    request: proto::GetChannelMembers,
3276    response: Response<proto::GetChannelMembers>,
3277    session: Session,
3278) -> Result<()> {
3279    let db = session.db().await;
3280    let channel_id = ChannelId::from_proto(request.channel_id);
3281    let limit = if request.limit == 0 {
3282        u16::MAX as u64
3283    } else {
3284        request.limit
3285    };
3286    let (members, users) = db
3287        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3288        .await?;
3289    response.send(proto::GetChannelMembersResponse { members, users })?;
3290    Ok(())
3291}
3292
3293/// Accept or decline a channel invitation.
3294async fn respond_to_channel_invite(
3295    request: proto::RespondToChannelInvite,
3296    response: Response<proto::RespondToChannelInvite>,
3297    session: Session,
3298) -> Result<()> {
3299    let db = session.db().await;
3300    let channel_id = ChannelId::from_proto(request.channel_id);
3301    let RespondToChannelInvite {
3302        membership_update,
3303        notifications,
3304    } = db
3305        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3306        .await?;
3307
3308    let mut connection_pool = session.connection_pool().await;
3309    if let Some(membership_update) = membership_update {
3310        notify_membership_updated(
3311            &mut connection_pool,
3312            membership_update,
3313            session.user_id(),
3314            &session.peer,
3315        );
3316    } else {
3317        let update = proto::UpdateChannels {
3318            remove_channel_invitations: vec![channel_id.to_proto()],
3319            ..Default::default()
3320        };
3321
3322        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3323            session.peer.send(connection_id, update.clone())?;
3324        }
3325    };
3326
3327    send_notifications(&connection_pool, &session.peer, notifications);
3328
3329    response.send(proto::Ack {})?;
3330
3331    Ok(())
3332}
3333
3334/// Join the channels' room
3335async fn join_channel(
3336    request: proto::JoinChannel,
3337    response: Response<proto::JoinChannel>,
3338    session: Session,
3339) -> Result<()> {
3340    let channel_id = ChannelId::from_proto(request.channel_id);
3341    join_channel_internal(channel_id, Box::new(response), session).await
3342}
3343
3344trait JoinChannelInternalResponse {
3345    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3346}
3347impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3348    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3349        Response::<proto::JoinChannel>::send(self, result)
3350    }
3351}
3352impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3353    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3354        Response::<proto::JoinRoom>::send(self, result)
3355    }
3356}
3357
3358async fn join_channel_internal(
3359    channel_id: ChannelId,
3360    response: Box<impl JoinChannelInternalResponse>,
3361    session: Session,
3362) -> Result<()> {
3363    let joined_room = {
3364        let mut db = session.db().await;
3365        // If zed quits without leaving the room, and the user re-opens zed before the
3366        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3367        // room they were in.
3368        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3369            tracing::info!(
3370                stale_connection_id = %connection,
3371                "cleaning up stale connection",
3372            );
3373            drop(db);
3374            leave_room_for_session(&session, connection).await?;
3375            db = session.db().await;
3376        }
3377
3378        let (joined_room, membership_updated, role) = db
3379            .join_channel(channel_id, session.user_id(), session.connection_id)
3380            .await?;
3381
3382        let live_kit_connection_info =
3383            session
3384                .app_state
3385                .livekit_client
3386                .as_ref()
3387                .and_then(|live_kit| {
3388                    let (can_publish, token) = if role == ChannelRole::Guest {
3389                        (
3390                            false,
3391                            live_kit
3392                                .guest_token(
3393                                    &joined_room.room.livekit_room,
3394                                    &session.user_id().to_string(),
3395                                )
3396                                .trace_err()?,
3397                        )
3398                    } else {
3399                        (
3400                            true,
3401                            live_kit
3402                                .room_token(
3403                                    &joined_room.room.livekit_room,
3404                                    &session.user_id().to_string(),
3405                                )
3406                                .trace_err()?,
3407                        )
3408                    };
3409
3410                    Some(LiveKitConnectionInfo {
3411                        server_url: live_kit.url().into(),
3412                        token,
3413                        can_publish,
3414                    })
3415                });
3416
3417        response.send(proto::JoinRoomResponse {
3418            room: Some(joined_room.room.clone()),
3419            channel_id: joined_room
3420                .channel
3421                .as_ref()
3422                .map(|channel| channel.id.to_proto()),
3423            live_kit_connection_info,
3424        })?;
3425
3426        let mut connection_pool = session.connection_pool().await;
3427        if let Some(membership_updated) = membership_updated {
3428            notify_membership_updated(
3429                &mut connection_pool,
3430                membership_updated,
3431                session.user_id(),
3432                &session.peer,
3433            );
3434        }
3435
3436        room_updated(&joined_room.room, &session.peer);
3437
3438        joined_room
3439    };
3440
3441    channel_updated(
3442        &joined_room.channel.context("channel not returned")?,
3443        &joined_room.room,
3444        &session.peer,
3445        &*session.connection_pool().await,
3446    );
3447
3448    update_user_contacts(session.user_id(), &session).await?;
3449    Ok(())
3450}
3451
3452/// Start editing the channel notes
3453async fn join_channel_buffer(
3454    request: proto::JoinChannelBuffer,
3455    response: Response<proto::JoinChannelBuffer>,
3456    session: Session,
3457) -> Result<()> {
3458    let db = session.db().await;
3459    let channel_id = ChannelId::from_proto(request.channel_id);
3460
3461    let open_response = db
3462        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3463        .await?;
3464
3465    let collaborators = open_response.collaborators.clone();
3466    response.send(open_response)?;
3467
3468    let update = UpdateChannelBufferCollaborators {
3469        channel_id: channel_id.to_proto(),
3470        collaborators: collaborators.clone(),
3471    };
3472    channel_buffer_updated(
3473        session.connection_id,
3474        collaborators
3475            .iter()
3476            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3477        &update,
3478        &session.peer,
3479    );
3480
3481    Ok(())
3482}
3483
3484/// Edit the channel notes
3485async fn update_channel_buffer(
3486    request: proto::UpdateChannelBuffer,
3487    session: Session,
3488) -> Result<()> {
3489    let db = session.db().await;
3490    let channel_id = ChannelId::from_proto(request.channel_id);
3491
3492    let (collaborators, epoch, version) = db
3493        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3494        .await?;
3495
3496    channel_buffer_updated(
3497        session.connection_id,
3498        collaborators.clone(),
3499        &proto::UpdateChannelBuffer {
3500            channel_id: channel_id.to_proto(),
3501            operations: request.operations,
3502        },
3503        &session.peer,
3504    );
3505
3506    let pool = &*session.connection_pool().await;
3507
3508    let non_collaborators =
3509        pool.channel_connection_ids(channel_id)
3510            .filter_map(|(connection_id, _)| {
3511                if collaborators.contains(&connection_id) {
3512                    None
3513                } else {
3514                    Some(connection_id)
3515                }
3516            });
3517
3518    broadcast(None, non_collaborators, |peer_id| {
3519        session.peer.send(
3520            peer_id,
3521            proto::UpdateChannels {
3522                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3523                    channel_id: channel_id.to_proto(),
3524                    epoch: epoch as u64,
3525                    version: version.clone(),
3526                }],
3527                ..Default::default()
3528            },
3529        )
3530    });
3531
3532    Ok(())
3533}
3534
3535/// Rejoin the channel notes after a connection blip
3536async fn rejoin_channel_buffers(
3537    request: proto::RejoinChannelBuffers,
3538    response: Response<proto::RejoinChannelBuffers>,
3539    session: Session,
3540) -> Result<()> {
3541    let db = session.db().await;
3542    let buffers = db
3543        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3544        .await?;
3545
3546    for rejoined_buffer in &buffers {
3547        let collaborators_to_notify = rejoined_buffer
3548            .buffer
3549            .collaborators
3550            .iter()
3551            .filter_map(|c| Some(c.peer_id?.into()));
3552        channel_buffer_updated(
3553            session.connection_id,
3554            collaborators_to_notify,
3555            &proto::UpdateChannelBufferCollaborators {
3556                channel_id: rejoined_buffer.buffer.channel_id,
3557                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3558            },
3559            &session.peer,
3560        );
3561    }
3562
3563    response.send(proto::RejoinChannelBuffersResponse {
3564        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3565    })?;
3566
3567    Ok(())
3568}
3569
3570/// Stop editing the channel notes
3571async fn leave_channel_buffer(
3572    request: proto::LeaveChannelBuffer,
3573    response: Response<proto::LeaveChannelBuffer>,
3574    session: Session,
3575) -> Result<()> {
3576    let db = session.db().await;
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578
3579    let left_buffer = db
3580        .leave_channel_buffer(channel_id, session.connection_id)
3581        .await?;
3582
3583    response.send(Ack {})?;
3584
3585    channel_buffer_updated(
3586        session.connection_id,
3587        left_buffer.connections,
3588        &proto::UpdateChannelBufferCollaborators {
3589            channel_id: channel_id.to_proto(),
3590            collaborators: left_buffer.collaborators,
3591        },
3592        &session.peer,
3593    );
3594
3595    Ok(())
3596}
3597
3598fn channel_buffer_updated<T: EnvelopedMessage>(
3599    sender_id: ConnectionId,
3600    collaborators: impl IntoIterator<Item = ConnectionId>,
3601    message: &T,
3602    peer: &Peer,
3603) {
3604    broadcast(Some(sender_id), collaborators, |peer_id| {
3605        peer.send(peer_id, message.clone())
3606    });
3607}
3608
3609fn send_notifications(
3610    connection_pool: &ConnectionPool,
3611    peer: &Peer,
3612    notifications: db::NotificationBatch,
3613) {
3614    for (user_id, notification) in notifications {
3615        for connection_id in connection_pool.user_connection_ids(user_id) {
3616            if let Err(error) = peer.send(
3617                connection_id,
3618                proto::AddNotification {
3619                    notification: Some(notification.clone()),
3620                },
3621            ) {
3622                tracing::error!(
3623                    "failed to send notification to {:?} {}",
3624                    connection_id,
3625                    error
3626                );
3627            }
3628        }
3629    }
3630}
3631
3632/// Send a message to the channel
3633async fn send_channel_message(
3634    request: proto::SendChannelMessage,
3635    response: Response<proto::SendChannelMessage>,
3636    session: Session,
3637) -> Result<()> {
3638    // Validate the message body.
3639    let body = request.body.trim().to_string();
3640    if body.len() > MAX_MESSAGE_LEN {
3641        return Err(anyhow!("message is too long"))?;
3642    }
3643    if body.is_empty() {
3644        return Err(anyhow!("message can't be blank"))?;
3645    }
3646
3647    // TODO: adjust mentions if body is trimmed
3648
3649    let timestamp = OffsetDateTime::now_utc();
3650    let nonce = request.nonce.context("nonce can't be blank")?;
3651
3652    let channel_id = ChannelId::from_proto(request.channel_id);
3653    let CreatedChannelMessage {
3654        message_id,
3655        participant_connection_ids,
3656        notifications,
3657    } = session
3658        .db()
3659        .await
3660        .create_channel_message(
3661            channel_id,
3662            session.user_id(),
3663            &body,
3664            &request.mentions,
3665            timestamp,
3666            nonce.clone().into(),
3667            request.reply_to_message_id.map(MessageId::from_proto),
3668        )
3669        .await?;
3670
3671    let message = proto::ChannelMessage {
3672        sender_id: session.user_id().to_proto(),
3673        id: message_id.to_proto(),
3674        body,
3675        mentions: request.mentions,
3676        timestamp: timestamp.unix_timestamp() as u64,
3677        nonce: Some(nonce),
3678        reply_to_message_id: request.reply_to_message_id,
3679        edited_at: None,
3680    };
3681    broadcast(
3682        Some(session.connection_id),
3683        participant_connection_ids.clone(),
3684        |connection| {
3685            session.peer.send(
3686                connection,
3687                proto::ChannelMessageSent {
3688                    channel_id: channel_id.to_proto(),
3689                    message: Some(message.clone()),
3690                },
3691            )
3692        },
3693    );
3694    response.send(proto::SendChannelMessageResponse {
3695        message: Some(message),
3696    })?;
3697
3698    let pool = &*session.connection_pool().await;
3699    let non_participants =
3700        pool.channel_connection_ids(channel_id)
3701            .filter_map(|(connection_id, _)| {
3702                if participant_connection_ids.contains(&connection_id) {
3703                    None
3704                } else {
3705                    Some(connection_id)
3706                }
3707            });
3708    broadcast(None, non_participants, |peer_id| {
3709        session.peer.send(
3710            peer_id,
3711            proto::UpdateChannels {
3712                latest_channel_message_ids: vec![proto::ChannelMessageId {
3713                    channel_id: channel_id.to_proto(),
3714                    message_id: message_id.to_proto(),
3715                }],
3716                ..Default::default()
3717            },
3718        )
3719    });
3720    send_notifications(pool, &session.peer, notifications);
3721
3722    Ok(())
3723}
3724
3725/// Delete a channel message
3726async fn remove_channel_message(
3727    request: proto::RemoveChannelMessage,
3728    response: Response<proto::RemoveChannelMessage>,
3729    session: Session,
3730) -> Result<()> {
3731    let channel_id = ChannelId::from_proto(request.channel_id);
3732    let message_id = MessageId::from_proto(request.message_id);
3733    let (connection_ids, existing_notification_ids) = session
3734        .db()
3735        .await
3736        .remove_channel_message(channel_id, message_id, session.user_id())
3737        .await?;
3738
3739    broadcast(
3740        Some(session.connection_id),
3741        connection_ids,
3742        move |connection| {
3743            session.peer.send(connection, request.clone())?;
3744
3745            for notification_id in &existing_notification_ids {
3746                session.peer.send(
3747                    connection,
3748                    proto::DeleteNotification {
3749                        notification_id: (*notification_id).to_proto(),
3750                    },
3751                )?;
3752            }
3753
3754            Ok(())
3755        },
3756    );
3757    response.send(proto::Ack {})?;
3758    Ok(())
3759}
3760
3761async fn update_channel_message(
3762    request: proto::UpdateChannelMessage,
3763    response: Response<proto::UpdateChannelMessage>,
3764    session: Session,
3765) -> Result<()> {
3766    let channel_id = ChannelId::from_proto(request.channel_id);
3767    let message_id = MessageId::from_proto(request.message_id);
3768    let updated_at = OffsetDateTime::now_utc();
3769    let UpdatedChannelMessage {
3770        message_id,
3771        participant_connection_ids,
3772        notifications,
3773        reply_to_message_id,
3774        timestamp,
3775        deleted_mention_notification_ids,
3776        updated_mention_notifications,
3777    } = session
3778        .db()
3779        .await
3780        .update_channel_message(
3781            channel_id,
3782            message_id,
3783            session.user_id(),
3784            request.body.as_str(),
3785            &request.mentions,
3786            updated_at,
3787        )
3788        .await?;
3789
3790    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3791
3792    let message = proto::ChannelMessage {
3793        sender_id: session.user_id().to_proto(),
3794        id: message_id.to_proto(),
3795        body: request.body.clone(),
3796        mentions: request.mentions.clone(),
3797        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3798        nonce: Some(nonce),
3799        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3800        edited_at: Some(updated_at.unix_timestamp() as u64),
3801    };
3802
3803    response.send(proto::Ack {})?;
3804
3805    let pool = &*session.connection_pool().await;
3806    broadcast(
3807        Some(session.connection_id),
3808        participant_connection_ids,
3809        |connection| {
3810            session.peer.send(
3811                connection,
3812                proto::ChannelMessageUpdate {
3813                    channel_id: channel_id.to_proto(),
3814                    message: Some(message.clone()),
3815                },
3816            )?;
3817
3818            for notification_id in &deleted_mention_notification_ids {
3819                session.peer.send(
3820                    connection,
3821                    proto::DeleteNotification {
3822                        notification_id: (*notification_id).to_proto(),
3823                    },
3824                )?;
3825            }
3826
3827            for notification in &updated_mention_notifications {
3828                session.peer.send(
3829                    connection,
3830                    proto::UpdateNotification {
3831                        notification: Some(notification.clone()),
3832                    },
3833                )?;
3834            }
3835
3836            Ok(())
3837        },
3838    );
3839
3840    send_notifications(pool, &session.peer, notifications);
3841
3842    Ok(())
3843}
3844
3845/// Mark a channel message as read
3846async fn acknowledge_channel_message(
3847    request: proto::AckChannelMessage,
3848    session: Session,
3849) -> Result<()> {
3850    let channel_id = ChannelId::from_proto(request.channel_id);
3851    let message_id = MessageId::from_proto(request.message_id);
3852    let notifications = session
3853        .db()
3854        .await
3855        .observe_channel_message(channel_id, session.user_id(), message_id)
3856        .await?;
3857    send_notifications(
3858        &*session.connection_pool().await,
3859        &session.peer,
3860        notifications,
3861    );
3862    Ok(())
3863}
3864
3865/// Mark a buffer version as synced
3866async fn acknowledge_buffer_version(
3867    request: proto::AckBufferOperation,
3868    session: Session,
3869) -> Result<()> {
3870    let buffer_id = BufferId::from_proto(request.buffer_id);
3871    session
3872        .db()
3873        .await
3874        .observe_buffer_version(
3875            buffer_id,
3876            session.user_id(),
3877            request.epoch as i32,
3878            &request.version,
3879        )
3880        .await?;
3881    Ok(())
3882}
3883
3884/// Get a Supermaven API key for the user
3885async fn get_supermaven_api_key(
3886    _request: proto::GetSupermavenApiKey,
3887    response: Response<proto::GetSupermavenApiKey>,
3888    session: Session,
3889) -> Result<()> {
3890    let user_id: String = session.user_id().to_string();
3891    if !session.is_staff() {
3892        return Err(anyhow!("supermaven not enabled for this account"))?;
3893    }
3894
3895    let email = session.email().context("user must have an email")?;
3896
3897    let supermaven_admin_api = session
3898        .supermaven_client
3899        .as_ref()
3900        .context("supermaven not configured")?;
3901
3902    let result = supermaven_admin_api
3903        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3904        .await?;
3905
3906    response.send(proto::GetSupermavenApiKeyResponse {
3907        api_key: result.api_key,
3908    })?;
3909
3910    Ok(())
3911}
3912
3913/// Start receiving chat updates for a channel
3914async fn join_channel_chat(
3915    request: proto::JoinChannelChat,
3916    response: Response<proto::JoinChannelChat>,
3917    session: Session,
3918) -> Result<()> {
3919    let channel_id = ChannelId::from_proto(request.channel_id);
3920
3921    let db = session.db().await;
3922    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3923        .await?;
3924    let messages = db
3925        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3926        .await?;
3927    response.send(proto::JoinChannelChatResponse {
3928        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3929        messages,
3930    })?;
3931    Ok(())
3932}
3933
3934/// Stop receiving chat updates for a channel
3935async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3936    let channel_id = ChannelId::from_proto(request.channel_id);
3937    session
3938        .db()
3939        .await
3940        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3941        .await?;
3942    Ok(())
3943}
3944
3945/// Retrieve the chat history for a channel
3946async fn get_channel_messages(
3947    request: proto::GetChannelMessages,
3948    response: Response<proto::GetChannelMessages>,
3949    session: Session,
3950) -> Result<()> {
3951    let channel_id = ChannelId::from_proto(request.channel_id);
3952    let messages = session
3953        .db()
3954        .await
3955        .get_channel_messages(
3956            channel_id,
3957            session.user_id(),
3958            MESSAGE_COUNT_PER_PAGE,
3959            Some(MessageId::from_proto(request.before_message_id)),
3960        )
3961        .await?;
3962    response.send(proto::GetChannelMessagesResponse {
3963        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3964        messages,
3965    })?;
3966    Ok(())
3967}
3968
3969/// Retrieve specific chat messages
3970async fn get_channel_messages_by_id(
3971    request: proto::GetChannelMessagesById,
3972    response: Response<proto::GetChannelMessagesById>,
3973    session: Session,
3974) -> Result<()> {
3975    let message_ids = request
3976        .message_ids
3977        .iter()
3978        .map(|id| MessageId::from_proto(*id))
3979        .collect::<Vec<_>>();
3980    let messages = session
3981        .db()
3982        .await
3983        .get_channel_messages_by_id(session.user_id(), &message_ids)
3984        .await?;
3985    response.send(proto::GetChannelMessagesResponse {
3986        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3987        messages,
3988    })?;
3989    Ok(())
3990}
3991
3992/// Retrieve the current users notifications
3993async fn get_notifications(
3994    request: proto::GetNotifications,
3995    response: Response<proto::GetNotifications>,
3996    session: Session,
3997) -> Result<()> {
3998    let notifications = session
3999        .db()
4000        .await
4001        .get_notifications(
4002            session.user_id(),
4003            NOTIFICATION_COUNT_PER_PAGE,
4004            request.before_id.map(db::NotificationId::from_proto),
4005        )
4006        .await?;
4007    response.send(proto::GetNotificationsResponse {
4008        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4009        notifications,
4010    })?;
4011    Ok(())
4012}
4013
4014/// Mark notifications as read
4015async fn mark_notification_as_read(
4016    request: proto::MarkNotificationRead,
4017    response: Response<proto::MarkNotificationRead>,
4018    session: Session,
4019) -> Result<()> {
4020    let database = &session.db().await;
4021    let notifications = database
4022        .mark_notification_as_read_by_id(
4023            session.user_id(),
4024            NotificationId::from_proto(request.notification_id),
4025        )
4026        .await?;
4027    send_notifications(
4028        &*session.connection_pool().await,
4029        &session.peer,
4030        notifications,
4031    );
4032    response.send(proto::Ack {})?;
4033    Ok(())
4034}
4035
4036/// Get the current users information
4037async fn get_private_user_info(
4038    _request: proto::GetPrivateUserInfo,
4039    response: Response<proto::GetPrivateUserInfo>,
4040    session: Session,
4041) -> Result<()> {
4042    let db = session.db().await;
4043
4044    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4045    let user = db
4046        .get_user_by_id(session.user_id())
4047        .await?
4048        .context("user not found")?;
4049    let flags = db.get_user_flags(session.user_id()).await?;
4050
4051    response.send(proto::GetPrivateUserInfoResponse {
4052        metrics_id,
4053        staff: user.admin,
4054        flags,
4055        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4056    })?;
4057    Ok(())
4058}
4059
4060/// Accept the terms of service (tos) on behalf of the current user
4061async fn accept_terms_of_service(
4062    _request: proto::AcceptTermsOfService,
4063    response: Response<proto::AcceptTermsOfService>,
4064    session: Session,
4065) -> Result<()> {
4066    let db = session.db().await;
4067
4068    let accepted_tos_at = Utc::now();
4069    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4070        .await?;
4071
4072    response.send(proto::AcceptTermsOfServiceResponse {
4073        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4074    })?;
4075    Ok(())
4076}
4077
/// The minimum age (30 days) an account must have reached in order to use the LLM service.
///
/// NOTE(review): nothing in the nearby handlers reads this constant; presumably it is
/// enforced elsewhere. Confirm against its callers.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4080
4081async fn get_llm_api_token(
4082    _request: proto::GetLlmToken,
4083    response: Response<proto::GetLlmToken>,
4084    session: Session,
4085) -> Result<()> {
4086    let db = session.db().await;
4087
4088    let flags = db.get_user_flags(session.user_id()).await?;
4089
4090    let user_id = session.user_id();
4091    let user = db
4092        .get_user_by_id(user_id)
4093        .await?
4094        .with_context(|| format!("user {user_id} not found"))?;
4095
4096    if user.accepted_tos_at.is_none() {
4097        Err(anyhow!("terms of service not accepted"))?
4098    }
4099
4100    let stripe_client = session
4101        .app_state
4102        .stripe_client
4103        .as_ref()
4104        .context("failed to retrieve Stripe client")?;
4105
4106    let stripe_billing = session
4107        .app_state
4108        .stripe_billing
4109        .as_ref()
4110        .context("failed to retrieve Stripe billing object")?;
4111
4112    let billing_customer = if let Some(billing_customer) =
4113        db.get_billing_customer_by_user_id(user.id).await?
4114    {
4115        billing_customer
4116    } else {
4117        let customer_id = stripe_billing
4118            .find_or_create_customer_by_email(user.email_address.as_deref())
4119            .await?;
4120
4121        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4122            .await?
4123            .context("billing customer not found")?
4124    };
4125
4126    let billing_subscription =
4127        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4128            billing_subscription
4129        } else {
4130            let stripe_customer_id =
4131                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4132
4133            let stripe_subscription = stripe_billing
4134                .subscribe_to_zed_free(stripe_customer_id)
4135                .await?;
4136
4137            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4138                billing_customer_id: billing_customer.id,
4139                kind: Some(SubscriptionKind::ZedFree),
4140                stripe_subscription_id: stripe_subscription.id.to_string(),
4141                stripe_subscription_status: stripe_subscription.status.into(),
4142                stripe_cancellation_reason: None,
4143                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4144                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4145            })
4146            .await?
4147        };
4148
4149    let billing_preferences = db.get_billing_preferences(user.id).await?;
4150
4151    let token = LlmTokenClaims::create(
4152        &user,
4153        session.is_staff(),
4154        billing_customer,
4155        billing_preferences,
4156        &flags,
4157        billing_subscription,
4158        session.system_id.clone(),
4159        &session.app_state.config,
4160    )?;
4161    response.send(proto::GetLlmTokenResponse { token })?;
4162    Ok(())
4163}
4164
4165fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4166    let message = match message {
4167        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4168        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4169        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4170        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4171        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4172            code: frame.code.into(),
4173            reason: frame.reason.as_str().to_owned().into(),
4174        })),
4175        // We should never receive a frame while reading the message, according
4176        // to the `tungstenite` maintainers:
4177        //
4178        // > It cannot occur when you read messages from the WebSocket, but it
4179        // > can be used when you want to send the raw frames (e.g. you want to
4180        // > send the frames to the WebSocket without composing the full message first).
4181        // >
4182        // > — https://github.com/snapview/tungstenite-rs/issues/268
4183        TungsteniteMessage::Frame(_) => {
4184            bail!("received an unexpected frame while reading the message")
4185        }
4186    };
4187
4188    Ok(message)
4189}
4190
4191fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4192    match message {
4193        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4194        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4195        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4196        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4197        AxumMessage::Close(frame) => {
4198            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4199                code: frame.code.into(),
4200                reason: frame.reason.as_ref().into(),
4201            }))
4202        }
4203    }
4204}
4205
4206fn notify_membership_updated(
4207    connection_pool: &mut ConnectionPool,
4208    result: MembershipUpdated,
4209    user_id: UserId,
4210    peer: &Peer,
4211) {
4212    for membership in &result.new_channels.channel_memberships {
4213        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4214    }
4215    for channel_id in &result.removed_channels {
4216        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4217    }
4218
4219    let user_channels_update = proto::UpdateUserChannels {
4220        channel_memberships: result
4221            .new_channels
4222            .channel_memberships
4223            .iter()
4224            .map(|cm| proto::ChannelMembership {
4225                channel_id: cm.channel_id.to_proto(),
4226                role: cm.role.into(),
4227            })
4228            .collect(),
4229        ..Default::default()
4230    };
4231
4232    let mut update = build_channels_update(result.new_channels);
4233    update.delete_channels = result
4234        .removed_channels
4235        .into_iter()
4236        .map(|id| id.to_proto())
4237        .collect();
4238    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4239
4240    for connection_id in connection_pool.user_connection_ids(user_id) {
4241        peer.send(connection_id, user_channels_update.clone())
4242            .trace_err();
4243        peer.send(connection_id, update.clone()).trace_err();
4244    }
4245}
4246
4247fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4248    proto::UpdateUserChannels {
4249        channel_memberships: channels
4250            .channel_memberships
4251            .iter()
4252            .map(|m| proto::ChannelMembership {
4253                channel_id: m.channel_id.to_proto(),
4254                role: m.role.into(),
4255            })
4256            .collect(),
4257        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4258        observed_channel_message_id: channels.observed_channel_messages.clone(),
4259    }
4260}
4261
4262fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4263    let mut update = proto::UpdateChannels::default();
4264
4265    for channel in channels.channels {
4266        update.channels.push(channel.to_proto());
4267    }
4268
4269    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4270    update.latest_channel_message_ids = channels.latest_channel_messages;
4271
4272    for (channel_id, participants) in channels.channel_participants {
4273        update
4274            .channel_participants
4275            .push(proto::ChannelParticipants {
4276                channel_id: channel_id.to_proto(),
4277                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4278            });
4279    }
4280
4281    for channel in channels.invited_channels {
4282        update.channel_invitations.push(channel.to_proto());
4283    }
4284
4285    update
4286}
4287
4288fn build_initial_contacts_update(
4289    contacts: Vec<db::Contact>,
4290    pool: &ConnectionPool,
4291) -> proto::UpdateContacts {
4292    let mut update = proto::UpdateContacts::default();
4293
4294    for contact in contacts {
4295        match contact {
4296            db::Contact::Accepted { user_id, busy } => {
4297                update.contacts.push(contact_for_user(user_id, busy, pool));
4298            }
4299            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4300            db::Contact::Incoming { user_id } => {
4301                update
4302                    .incoming_requests
4303                    .push(proto::IncomingContactRequest {
4304                        requester_id: user_id.to_proto(),
4305                    })
4306            }
4307        }
4308    }
4309
4310    update
4311}
4312
4313fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4314    proto::Contact {
4315        user_id: user_id.to_proto(),
4316        online: pool.is_user_online(user_id),
4317        busy,
4318    }
4319}
4320
4321fn room_updated(room: &proto::Room, peer: &Peer) {
4322    broadcast(
4323        None,
4324        room.participants
4325            .iter()
4326            .filter_map(|participant| Some(participant.peer_id?.into())),
4327        |peer_id| {
4328            peer.send(
4329                peer_id,
4330                proto::RoomUpdated {
4331                    room: Some(room.clone()),
4332                },
4333            )
4334        },
4335    );
4336}
4337
4338fn channel_updated(
4339    channel: &db::channel::Model,
4340    room: &proto::Room,
4341    peer: &Peer,
4342    pool: &ConnectionPool,
4343) {
4344    let participants = room
4345        .participants
4346        .iter()
4347        .map(|p| p.user_id)
4348        .collect::<Vec<_>>();
4349
4350    broadcast(
4351        None,
4352        pool.channel_connection_ids(channel.root_id())
4353            .filter_map(|(channel_id, role)| {
4354                role.can_see_channel(channel.visibility)
4355                    .then_some(channel_id)
4356            }),
4357        |peer_id| {
4358            peer.send(
4359                peer_id,
4360                proto::UpdateChannels {
4361                    channel_participants: vec![proto::ChannelParticipants {
4362                        channel_id: channel.id.to_proto(),
4363                        participant_user_ids: participants.clone(),
4364                    }],
4365                    ..Default::default()
4366                },
4367            )
4368        },
4369    );
4370}
4371
4372async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4373    let db = session.db().await;
4374
4375    let contacts = db.get_contacts(user_id).await?;
4376    let busy = db.is_user_busy(user_id).await?;
4377
4378    let pool = session.connection_pool().await;
4379    let updated_contact = contact_for_user(user_id, busy, &pool);
4380    for contact in contacts {
4381        if let db::Contact::Accepted {
4382            user_id: contact_user_id,
4383            ..
4384        } = contact
4385        {
4386            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4387                session
4388                    .peer
4389                    .send(
4390                        contact_conn_id,
4391                        proto::UpdateContacts {
4392                            contacts: vec![updated_contact.clone()],
4393                            remove_contacts: Default::default(),
4394                            incoming_requests: Default::default(),
4395                            remove_incoming_requests: Default::default(),
4396                            outgoing_requests: Default::default(),
4397                            remove_outgoing_requests: Default::default(),
4398                        },
4399                    )
4400                    .trace_err();
4401            }
4402        }
4403    }
4404    Ok(())
4405}
4406
4407async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4408    let mut contacts_to_update = HashSet::default();
4409
4410    let room_id;
4411    let canceled_calls_to_user_ids;
4412    let livekit_room;
4413    let delete_livekit_room;
4414    let room;
4415    let channel;
4416
4417    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4418        contacts_to_update.insert(session.user_id());
4419
4420        for project in left_room.left_projects.values() {
4421            project_left(project, session);
4422        }
4423
4424        room_id = RoomId::from_proto(left_room.room.id);
4425        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4426        livekit_room = mem::take(&mut left_room.room.livekit_room);
4427        delete_livekit_room = left_room.deleted;
4428        room = mem::take(&mut left_room.room);
4429        channel = mem::take(&mut left_room.channel);
4430
4431        room_updated(&room, &session.peer);
4432    } else {
4433        return Ok(());
4434    }
4435
4436    if let Some(channel) = channel {
4437        channel_updated(
4438            &channel,
4439            &room,
4440            &session.peer,
4441            &*session.connection_pool().await,
4442        );
4443    }
4444
4445    {
4446        let pool = session.connection_pool().await;
4447        for canceled_user_id in canceled_calls_to_user_ids {
4448            for connection_id in pool.user_connection_ids(canceled_user_id) {
4449                session
4450                    .peer
4451                    .send(
4452                        connection_id,
4453                        proto::CallCanceled {
4454                            room_id: room_id.to_proto(),
4455                        },
4456                    )
4457                    .trace_err();
4458            }
4459            contacts_to_update.insert(canceled_user_id);
4460        }
4461    }
4462
4463    for contact_user_id in contacts_to_update {
4464        update_user_contacts(contact_user_id, session).await?;
4465    }
4466
4467    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4468        live_kit
4469            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4470            .await
4471            .trace_err();
4472
4473        if delete_livekit_room {
4474            live_kit.delete_room(livekit_room).await.trace_err();
4475        }
4476    }
4477
4478    Ok(())
4479}
4480
4481async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4482    let left_channel_buffers = session
4483        .db()
4484        .await
4485        .leave_channel_buffers(session.connection_id)
4486        .await?;
4487
4488    for left_buffer in left_channel_buffers {
4489        channel_buffer_updated(
4490            session.connection_id,
4491            left_buffer.connections,
4492            &proto::UpdateChannelBufferCollaborators {
4493                channel_id: left_buffer.channel_id.to_proto(),
4494                collaborators: left_buffer.collaborators,
4495            },
4496            &session.peer,
4497        );
4498    }
4499
4500    Ok(())
4501}
4502
4503fn project_left(project: &db::LeftProject, session: &Session) {
4504    for connection_id in &project.connection_ids {
4505        if project.should_unshare {
4506            session
4507                .peer
4508                .send(
4509                    *connection_id,
4510                    proto::UnshareProject {
4511                        project_id: project.id.to_proto(),
4512                    },
4513                )
4514                .trace_err();
4515        } else {
4516            session
4517                .peer
4518                .send(
4519                    *connection_id,
4520                    proto::RemoveProjectCollaborator {
4521                        project_id: project.id.to_proto(),
4522                        peer_id: Some(session.connection_id.into()),
4523                    },
4524                )
4525                .trace_err();
4526        }
4527    }
4528}
4529
4530pub trait ResultExt {
4531    type Ok;
4532
4533    fn trace_err(self) -> Option<Self::Ok>;
4534}
4535
4536impl<T, E> ResultExt for Result<T, E>
4537where
4538    E: std::fmt::Debug,
4539{
4540    type Ok = T;
4541
4542    #[track_caller]
4543    fn trace_err(self) -> Option<T> {
4544        match self {
4545            Ok(value) => Some(value),
4546            Err(error) => {
4547                tracing::error!("{:?}", error);
4548                None
4549            }
4550        }
4551    }
4552}