rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::stripe_client::StripeCustomerId;
   9use crate::{
  10    AppState, Error, Result, auth,
  11    db::{
  12        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  13        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  14        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  15        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  16    },
  17    executor::Executor,
  18};
  19use anyhow::{Context as _, anyhow, bail};
  20use async_tungstenite::tungstenite::{
  21    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  22};
  23use axum::{
  24    Extension, Router, TypedHeader,
  25    body::Body,
  26    extract::{
  27        ConnectInfo, WebSocketUpgrade,
  28        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  29    },
  30    headers::{Header, HeaderName},
  31    http::StatusCode,
  32    middleware,
  33    response::IntoResponse,
  34    routing::get,
  35};
  36use chrono::Utc;
  37use collections::{HashMap, HashSet};
  38pub use connection_pool::{ConnectionPool, ZedVersion};
  39use core::fmt::{self, Debug, Formatter};
  40use reqwest_client::ReqwestClient;
  41use rpc::proto::split_repository_update;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  46    stream::FuturesUnordered,
  47};
  48use prometheus::{IntGauge, register_int_gauge};
  49use rpc::{
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        Arc, OnceLock,
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{Semaphore, watch};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    Instrument,
  77    field::{self},
  78    info_span, instrument,
  79};
  80
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111}
 112
 113impl Principal {
 114    fn user(&self) -> &User {
 115        match self {
 116            Principal::User(user) => user,
 117            Principal::Impersonated { user, .. } => user,
 118        }
 119    }
 120
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132        }
 133    }
 134}
 135
/// Per-connection state passed (by value, cloned) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection (possibly an impersonation).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it via `Session::connection_pool`,
    /// which wraps the lock guard in a `!Send` `ConnectionPoolGuard`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is present.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // Not read anywhere in this section of the file, hence the lint allowance.
    // NOTE(review): confirm whether the field is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn is_staff(&self) -> bool {
 173        match &self.principal {
 174            Principal::User(user) => user.admin,
 175            Principal::Impersonated { .. } => true,
 176        }
 177    }
 178
 179    fn user_id(&self) -> UserId {
 180        match &self.principal {
 181            Principal::User(user) => user.id,
 182            Principal::Impersonated { user, .. } => user.id,
 183        }
 184    }
 185
 186    pub fn email(&self) -> Option<String> {
 187        match &self.principal {
 188            Principal::User(user) => user.email_address.clone(),
 189            Principal::Impersonated { user, .. } => user.email_address.clone(),
 190        }
 191    }
 192}
 193
 194impl Debug for Session {
 195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 196        let mut result = f.debug_struct("Session");
 197        match &self.principal {
 198            Principal::User(user) => {
 199                result.field("user", &user.github_login);
 200            }
 201            Principal::Impersonated { user, admin } => {
 202                result.field("user", &user.github_login);
 203                result.field("impersonator", &admin.github_login);
 204            }
 205        }
 206        result.field("connection_id", &self.connection_id).finish()
 207    }
 208}
 209
 210struct DbHandle(Arc<Database>);
 211
 212impl Deref for DbHandle {
 213    type Target = Database;
 214
 215    fn deref(&self) -> &Self::Target {
 216        self.0.as_ref()
 217    }
 218}
 219
/// The collaboration RPC server: owns the peer, the shared connection pool and
/// the table of message handlers keyed by message type.
pub struct Server {
    /// Current server id. It sits behind a mutex so the test-only `reset` can swap it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept,
    /// filled by `add_handler` (called from `Server::new`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag channel: `teardown` sends `true`, the test-only `reset`
    /// sends `false`. Send errors (no subscribers) are ignored by both.
    teardown: watch::Sender<bool>,
}
 228
/// Guard over the shared [`ConnectionPool`] lock.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside the `Send` futures that handlers return.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 233
/// Serializable view of the server's peer and connection pool.
///
/// The pool's lock is held for as long as the snapshot is alive, because the
/// snapshot owns the guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the pool the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 240
 241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 242where
 243    S: Serializer,
 244    T: Deref<Target = U>,
 245    U: Serialize,
 246{
 247    Serialize::serialize(value.deref(), serializer)
 248}
 249
 250impl Server {
 251    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 252        let mut server = Self {
 253            id: parking_lot::Mutex::new(id),
 254            peer: Peer::new(id.0 as u32),
 255            app_state: app_state.clone(),
 256            connection_pool: Default::default(),
 257            handlers: Default::default(),
 258            teardown: watch::channel(false).0,
 259        };
 260
 261        server
 262            .add_request_handler(ping)
 263            .add_request_handler(create_room)
 264            .add_request_handler(join_room)
 265            .add_request_handler(rejoin_room)
 266            .add_request_handler(leave_room)
 267            .add_request_handler(set_room_participant_role)
 268            .add_request_handler(call)
 269            .add_request_handler(cancel_call)
 270            .add_message_handler(decline_call)
 271            .add_request_handler(update_participant_location)
 272            .add_request_handler(share_project)
 273            .add_message_handler(unshare_project)
 274            .add_request_handler(join_project)
 275            .add_message_handler(leave_project)
 276            .add_request_handler(update_project)
 277            .add_request_handler(update_worktree)
 278            .add_request_handler(update_repository)
 279            .add_request_handler(remove_repository)
 280            .add_message_handler(start_language_server)
 281            .add_message_handler(update_language_server)
 282            .add_message_handler(update_diagnostic_summary)
 283            .add_message_handler(update_worktree_settings)
 284            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 285            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 286            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 287            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 288            .add_request_handler(forward_find_search_candidates_request)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 291            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 293            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 294            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 295            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 296            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 297            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 298            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 300            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 301            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 303            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 304            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 305            .add_request_handler(
 306                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 307            )
 308            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 309            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 310            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 311            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 312            .add_request_handler(
 313                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 314            )
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 349            .add_message_handler(create_buffer_for_peer)
 350            .add_request_handler(update_buffer)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 357            .add_request_handler(get_users)
 358            .add_request_handler(fuzzy_search_users)
 359            .add_request_handler(request_contact)
 360            .add_request_handler(remove_contact)
 361            .add_request_handler(respond_to_contact_request)
 362            .add_message_handler(subscribe_to_channels)
 363            .add_request_handler(create_channel)
 364            .add_request_handler(delete_channel)
 365            .add_request_handler(invite_channel_member)
 366            .add_request_handler(remove_channel_member)
 367            .add_request_handler(set_channel_member_role)
 368            .add_request_handler(set_channel_visibility)
 369            .add_request_handler(rename_channel)
 370            .add_request_handler(join_channel_buffer)
 371            .add_request_handler(leave_channel_buffer)
 372            .add_message_handler(update_channel_buffer)
 373            .add_request_handler(rejoin_channel_buffers)
 374            .add_request_handler(get_channel_members)
 375            .add_request_handler(respond_to_channel_invite)
 376            .add_request_handler(join_channel)
 377            .add_request_handler(join_channel_chat)
 378            .add_message_handler(leave_channel_chat)
 379            .add_request_handler(send_channel_message)
 380            .add_request_handler(remove_channel_message)
 381            .add_request_handler(update_channel_message)
 382            .add_request_handler(get_channel_messages)
 383            .add_request_handler(get_channel_messages_by_id)
 384            .add_request_handler(get_notifications)
 385            .add_request_handler(mark_notification_as_read)
 386            .add_request_handler(move_channel)
 387            .add_request_handler(reorder_channel)
 388            .add_request_handler(follow)
 389            .add_message_handler(unfollow)
 390            .add_message_handler(update_followers)
 391            .add_request_handler(get_private_user_info)
 392            .add_request_handler(get_llm_api_token)
 393            .add_request_handler(accept_terms_of_service)
 394            .add_message_handler(acknowledge_channel_message)
 395            .add_message_handler(acknowledge_buffer_version)
 396            .add_request_handler(get_supermaven_api_key)
 397            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 398            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 399            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 402            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 403            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 406            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 407            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 408            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 409            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 410            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 412            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 415            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 416            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 417            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 418            .add_message_handler(update_context);
 419
 420        Arc::new(server)
 421    }
 422
    /// Spawns a detached background task that sweeps stale resources (as
    /// reported by the `stale_*` / `clear_stale_*` database queries), then
    /// returns immediately.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT`, the task:
    /// 1. deletes stale channel-chat participants;
    /// 2. for each stale channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections;
    /// 3. for each stale room, it does the following:
    ///    - clears stale participants and broadcasts the room (and channel) update;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes refreshed contact state to the accepted contacts of affected users;
    ///    - deletes the LiveKit room once no participants remain;
    /// 4. deletes stale channel-chat participants again, clears old worktree
    ///    entries, and deletes stale servers.
    ///
    /// Every database/send error is only logged via `trace_err`; this method
    /// itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, tell the rest.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: these defaults stay in place if the refresh query fails,
                    // in which case nothing is sent and the LiveKit room is kept.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the
                        // next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry and push it to
                        // all connections of their accepted contacts. The pool lock is
                        // taken only after both awaits have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the call made before the room sweep. It is
                // presumably meant to catch anything left over from the sweep;
                // confirm whether both calls are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 593
 594    pub fn teardown(&self) {
 595        self.peer.teardown();
 596        self.connection_pool.lock().reset();
 597        let _ = self.teardown.send(true);
 598    }
 599
 600    #[cfg(test)]
 601    pub fn reset(&self, id: ServerId) {
 602        self.teardown();
 603        *self.id.lock() = id;
 604        self.peer.reset(id.0 as u32);
 605        let _ = self.teardown.send(false);
 606    }
 607
 608    #[cfg(test)]
 609    pub fn id(&self) -> ServerId {
 610        *self.id.lock()
 611    }
 612
    /// Registers `handler` as the processor for messages whose payload type is `M`.
    ///
    /// The handler is type-erased and stored in `self.handlers` under `TypeId::of::<M>()`.
    /// Each invocation is wrapped so that, once the handler's future resolves, the total,
    /// processing, and queue durations are logged together with either success or the
    /// returned error.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Entries are keyed by `TypeId::of::<M>()`, so an envelope routed here is
                // expected to be a `TypedEnvelope<M>`; the unwrap panics if it is not.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Captured just before the handler is called: "processing" time runs from
                // here until its future resolves, and "queue" time is whatever elapsed
                // between `received_at` and this point.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    // The local variable names above double as structured tracing field names.
                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 661
 662    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 663    where
 664        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 665        Fut: 'static + Send + Future<Output = Result<()>>,
 666        M: EnvelopedMessage,
 667    {
 668        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 669        self
 670    }
 671
 672    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 673    where
 674        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 675        Fut: Send + Future<Output = Result<()>>,
 676        M: RequestMessage,
 677    {
 678        let handler = Arc::new(handler);
 679        self.add_handler(move |envelope, session| {
 680            let receipt = envelope.receipt();
 681            let handler = handler.clone();
 682            async move {
 683                let peer = session.peer.clone();
 684                let responded = Arc::new(AtomicBool::default());
 685                let response = Response {
 686                    peer: peer.clone(),
 687                    responded: responded.clone(),
 688                    receipt,
 689                };
 690                match (handler)(envelope.payload, response, session).await {
 691                    Ok(()) => {
 692                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 693                            Ok(())
 694                        } else {
 695                            Err(anyhow!("handler did not send a response"))?
 696                        }
 697                    }
 698                    Err(error) => {
 699                        let proto_err = match &error {
 700                            Error::Internal(err) => err.to_proto(),
 701                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 702                        };
 703                        peer.respond_with_error(receipt, proto_err)?;
 704                        Err(error)
 705                    }
 706                }
 707            }
 708        })
 709    }
 710
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future runs inside a "handle connection" span. It does the following:
    /// - registers the connection with the peer;
    /// - builds a `Session`;
    /// - sends the initial client update;
    /// - dispatches incoming messages to the registered handlers until the connection
    ///   closes, I/O fails, or server teardown is signalled.
    ///
    /// On close or I/O failure it runs `connection_lost`; on teardown it returns immediately.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`,
    /// which sends the assigned `ConnectionId` through it.
    ///
    /// NOTE(review): the early returns after `add_connection` (HTTP client construction
    /// failing, or `send_initial_client_update` failing) skip `connection_lost`. As a result,
    /// `peer.disconnect` is never called for this connection. If the initial update failed
    /// after `pool.add_connection`, the pool entry is not removed either. Confirm whether
    /// this is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to start serving a connection once teardown has already begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is acquired before the
            // next message is read and released only when that message's handler finishes, so
            // reading stalls while every permit is held.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, then I/O,
                // then progress on foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown returns without calling `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                // The handler is invoked (synchronous part) inside the span;
                                // the returned future is instrumented with it below.
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are spawned on the executor instead
                                // of being polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in progress before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 858
 859    async fn send_initial_client_update(
 860        &self,
 861        connection_id: ConnectionId,
 862        zed_version: ZedVersion,
 863        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 864        session: &Session,
 865    ) -> Result<()> {
 866        self.peer.send(
 867            connection_id,
 868            proto::Hello {
 869                peer_id: Some(connection_id.into()),
 870            },
 871        )?;
 872        tracing::info!("sent hello message");
 873        if let Some(send_connection_id) = send_connection_id.take() {
 874            let _ = send_connection_id.send(connection_id);
 875        }
 876
 877        match &session.principal {
 878            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 879                if !user.connected_once {
 880                    self.peer.send(connection_id, proto::ShowContacts {})?;
 881                    self.app_state
 882                        .db
 883                        .set_user_connected_once(user.id, true)
 884                        .await?;
 885                }
 886
 887                update_user_plan(session).await?;
 888
 889                let contacts = self.app_state.db.get_contacts(user.id).await?;
 890
 891                {
 892                    let mut pool = self.connection_pool.lock();
 893                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 894                    self.peer.send(
 895                        connection_id,
 896                        build_initial_contacts_update(contacts, &pool),
 897                    )?;
 898                }
 899
 900                if should_auto_subscribe_to_channels(zed_version) {
 901                    subscribe_user_to_channels(user.id, session).await?;
 902                }
 903
 904                if let Some(incoming_call) =
 905                    self.app_state.db.incoming_call_for_user(user.id).await?
 906                {
 907                    self.peer.send(connection_id, incoming_call)?;
 908                }
 909
 910                update_user_contacts(user.id, session).await?;
 911            }
 912        }
 913
 914        Ok(())
 915    }
 916
 917    pub async fn invite_code_redeemed(
 918        self: &Arc<Self>,
 919        inviter_id: UserId,
 920        invitee_id: UserId,
 921    ) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 923            if let Some(code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 926                for connection_id in pool.user_connection_ids(inviter_id) {
 927                    self.peer.send(
 928                        connection_id,
 929                        proto::UpdateContacts {
 930                            contacts: vec![invitee_contact.clone()],
 931                            ..Default::default()
 932                        },
 933                    )?;
 934                    self.peer.send(
 935                        connection_id,
 936                        proto::UpdateInviteInfo {
 937                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 938                            count: user.invite_count as u32,
 939                        },
 940                    )?;
 941                }
 942            }
 943        }
 944        Ok(())
 945    }
 946
 947    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 948        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 949            if let Some(invite_code) = &user.invite_code {
 950                let pool = self.connection_pool.lock();
 951                for connection_id in pool.user_connection_ids(user_id) {
 952                    self.peer.send(
 953                        connection_id,
 954                        proto::UpdateInviteInfo {
 955                            url: format!(
 956                                "{}{}",
 957                                self.app_state.config.invite_link_prefix, invite_code
 958                            ),
 959                            count: user.invite_count as u32,
 960                        },
 961                    )?;
 962                }
 963            }
 964        }
 965        Ok(())
 966    }
 967
 968    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 969        let user = self
 970            .app_state
 971            .db
 972            .get_user_by_id(user_id)
 973            .await?
 974            .context("user not found")?;
 975
 976        let update_user_plan = make_update_user_plan_message(
 977            &user,
 978            user.admin,
 979            &self.app_state.db,
 980            self.app_state.llm_db.clone(),
 981        )
 982        .await?;
 983
 984        let pool = self.connection_pool.lock();
 985        for connection_id in pool.user_connection_ids(user_id) {
 986            self.peer
 987                .send(connection_id, update_user_plan.clone())
 988                .trace_err();
 989        }
 990
 991        Ok(())
 992    }
 993
 994    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 995        let pool = self.connection_pool.lock();
 996        for connection_id in pool.user_connection_ids(user_id) {
 997            self.peer
 998                .send(connection_id, proto::RefreshLlmToken {})
 999                .trace_err();
1000        }
1001    }
1002
1003    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1004        ServerSnapshot {
1005            connection_pool: ConnectionPoolGuard {
1006                guard: self.connection_pool.lock(),
1007                _not_send: PhantomData,
1008            },
1009            peer: &self.peer,
1010        }
1011    }
1012}
1013
1014impl Deref for ConnectionPoolGuard<'_> {
1015    type Target = ConnectionPool;
1016
1017    fn deref(&self) -> &Self::Target {
1018        &self.guard
1019    }
1020}
1021
1022impl DerefMut for ConnectionPoolGuard<'_> {
1023    fn deref_mut(&mut self) -> &mut Self::Target {
1024        &mut self.guard
1025    }
1026}
1027
1028impl Drop for ConnectionPoolGuard<'_> {
1029    fn drop(&mut self) {
1030        #[cfg(test)]
1031        self.check_invariants();
1032    }
1033}
1034
1035fn broadcast<F>(
1036    sender_id: Option<ConnectionId>,
1037    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1038    mut f: F,
1039) where
1040    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1041{
1042    for receiver_id in receiver_ids {
1043        if Some(receiver_id) != sender_id {
1044            if let Err(error) = f(receiver_id) {
1045                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1046            }
1047        }
1048    }
1049}
1050
1051pub struct ProtocolVersion(u32);
1052
1053impl Header for ProtocolVersion {
1054    fn name() -> &'static HeaderName {
1055        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1057    }
1058
1059    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060    where
1061        Self: Sized,
1062        I: Iterator<Item = &'i axum::http::HeaderValue>,
1063    {
1064        let version = values
1065            .next()
1066            .ok_or_else(axum::headers::Error::invalid)?
1067            .to_str()
1068            .map_err(|_| axum::headers::Error::invalid())?
1069            .parse()
1070            .map_err(|_| axum::headers::Error::invalid())?;
1071        Ok(Self(version))
1072    }
1073
1074    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075        values.extend([self.0.to_string().parse().unwrap()]);
1076    }
1077}
1078
1079pub struct AppVersionHeader(SemanticVersion);
1080impl Header for AppVersionHeader {
1081    fn name() -> &'static HeaderName {
1082        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1083        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1084    }
1085
1086    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1087    where
1088        Self: Sized,
1089        I: Iterator<Item = &'i axum::http::HeaderValue>,
1090    {
1091        let version = values
1092            .next()
1093            .ok_or_else(axum::headers::Error::invalid)?
1094            .to_str()
1095            .map_err(|_| axum::headers::Error::invalid())?
1096            .parse()
1097            .map_err(|_| axum::headers::Error::invalid())?;
1098        Ok(Self(version))
1099    }
1100
1101    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1102        values.extend([self.0.to_string().parse().unwrap()]);
1103    }
1104}
1105
1106pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1107    Router::new()
1108        .route("/rpc", get(handle_websocket_request))
1109        .layer(
1110            ServiceBuilder::new()
1111                .layer(Extension(server.app_state.clone()))
1112                .layer(middleware::from_fn(auth::validate_header)),
1113        )
1114        .route("/metrics", get(handle_metrics))
1115        .layer(Extension(server))
1116}
1117
1118pub async fn handle_websocket_request(
1119    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1120    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1121    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1122    Extension(server): Extension<Arc<Server>>,
1123    Extension(principal): Extension<Principal>,
1124    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1125    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1126    ws: WebSocketUpgrade,
1127) -> axum::response::Response {
1128    if protocol_version != rpc::PROTOCOL_VERSION {
1129        return (
1130            StatusCode::UPGRADE_REQUIRED,
1131            "client must be upgraded".to_string(),
1132        )
1133            .into_response();
1134    }
1135
1136    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1137        return (
1138            StatusCode::UPGRADE_REQUIRED,
1139            "no version header found".to_string(),
1140        )
1141            .into_response();
1142    };
1143
1144    if !version.can_collaborate() {
1145        return (
1146            StatusCode::UPGRADE_REQUIRED,
1147            "client must be upgraded".to_string(),
1148        )
1149            .into_response();
1150    }
1151
1152    let socket_address = socket_address.to_string();
1153    ws.on_upgrade(move |socket| {
1154        let socket = socket
1155            .map_ok(to_tungstenite_message)
1156            .err_into()
1157            .with(|message| async move { to_axum_message(message) });
1158        let connection = Connection::new(Box::pin(socket));
1159        async move {
1160            server
1161                .handle_connection(
1162                    connection,
1163                    socket_address,
1164                    principal,
1165                    version,
1166                    country_code_header.map(|header| header.to_string()),
1167                    system_id_header.map(|header| header.to_string()),
1168                    None,
1169                    Executor::Production,
1170                )
1171                .await;
1172        }
1173    })
1174}
1175
1176pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1177    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1178    let connections_metric = CONNECTIONS_METRIC
1179        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1180
1181    let connections = server
1182        .connection_pool
1183        .lock()
1184        .connections()
1185        .filter(|connection| !connection.admin)
1186        .count();
1187    connections_metric.set(connections as _);
1188
1189    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1190    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1191        register_int_gauge!(
1192            "shared_projects",
1193            "number of open projects with one or more guests"
1194        )
1195        .unwrap()
1196    });
1197
1198    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1199    shared_projects_metric.set(shared_projects as _);
1200
1201    let encoder = prometheus::TextEncoder::new();
1202    let metric_families = prometheus::gather();
1203    let encoded_metrics = encoder
1204        .encode_to_string(&metric_families)
1205        .map_err(|err| anyhow!("{err}"))?;
1206    Ok(encoded_metrics)
1207}
1208
1209#[instrument(err, skip(executor))]
1210async fn connection_lost(
1211    session: Session,
1212    mut teardown: watch::Receiver<bool>,
1213    executor: Executor,
1214) -> Result<()> {
1215    session.peer.disconnect(session.connection_id);
1216    session
1217        .connection_pool()
1218        .await
1219        .remove_connection(session.connection_id)?;
1220
1221    session
1222        .db()
1223        .await
1224        .connection_lost(session.connection_id)
1225        .await
1226        .trace_err();
1227
1228    futures::select_biased! {
1229        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1230
1231            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1232            leave_room_for_session(&session, session.connection_id).await.trace_err();
1233            leave_channel_buffers_for_session(&session)
1234                .await
1235                .trace_err();
1236
1237            if !session
1238                .connection_pool()
1239                .await
1240                .is_user_online(session.user_id())
1241            {
1242                let db = session.db().await;
1243                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1244                    room_updated(&room, &session.peer);
1245                }
1246            }
1247
1248            update_user_contacts(session.user_id(), &session).await?;
1249        },
1250        _ = teardown.changed().fuse() => {}
1251    }
1252
1253    Ok(())
1254}
1255
1256/// Acknowledges a ping from a client, used to keep the connection alive.
1257async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1258    response.send(proto::Ack {})?;
1259    Ok(())
1260}
1261
1262/// Creates a new room for calling (outside of channels)
1263async fn create_room(
1264    _request: proto::CreateRoom,
1265    response: Response<proto::CreateRoom>,
1266    session: Session,
1267) -> Result<()> {
1268    let livekit_room = nanoid::nanoid!(30);
1269
1270    let live_kit_connection_info = util::maybe!(async {
1271        let live_kit = session.app_state.livekit_client.as_ref();
1272        let live_kit = live_kit?;
1273        let user_id = session.user_id().to_string();
1274
1275        let token = live_kit
1276            .room_token(&livekit_room, &user_id.to_string())
1277            .trace_err()?;
1278
1279        Some(proto::LiveKitConnectionInfo {
1280            server_url: live_kit.url().into(),
1281            token,
1282            can_publish: true,
1283        })
1284    })
1285    .await;
1286
1287    let room = session
1288        .db()
1289        .await
1290        .create_room(session.user_id(), session.connection_id, &livekit_room)
1291        .await?;
1292
1293    response.send(proto::CreateRoomResponse {
1294        room: Some(room.clone()),
1295        live_kit_connection_info,
1296    })?;
1297
1298    update_user_contacts(session.user_id(), &session).await?;
1299    Ok(())
1300}
1301
1302/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1303async fn join_room(
1304    request: proto::JoinRoom,
1305    response: Response<proto::JoinRoom>,
1306    session: Session,
1307) -> Result<()> {
1308    let room_id = RoomId::from_proto(request.id);
1309
1310    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1311
1312    if let Some(channel_id) = channel_id {
1313        return join_channel_internal(channel_id, Box::new(response), session).await;
1314    }
1315
1316    let joined_room = {
1317        let room = session
1318            .db()
1319            .await
1320            .join_room(room_id, session.user_id(), session.connection_id)
1321            .await?;
1322        room_updated(&room.room, &session.peer);
1323        room.into_inner()
1324    };
1325
1326    for connection_id in session
1327        .connection_pool()
1328        .await
1329        .user_connection_ids(session.user_id())
1330    {
1331        session
1332            .peer
1333            .send(
1334                connection_id,
1335                proto::CallCanceled {
1336                    room_id: room_id.to_proto(),
1337                },
1338            )
1339            .trace_err();
1340    }
1341
1342    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1343    {
1344        live_kit
1345            .room_token(
1346                &joined_room.room.livekit_room,
1347                &session.user_id().to_string(),
1348            )
1349            .trace_err()
1350            .map(|token| proto::LiveKitConnectionInfo {
1351                server_url: live_kit.url().into(),
1352                token,
1353                can_publish: true,
1354            })
1355    } else {
1356        None
1357    };
1358
1359    response.send(proto::JoinRoomResponse {
1360        room: Some(joined_room.room),
1361        channel_id: None,
1362        live_kit_connection_info,
1363    })?;
1364
1365    update_user_contacts(session.user_id(), &session).await?;
1366    Ok(())
1367}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// After the database rejoin, the caller is sent the room together with two project lists:
/// - "reshared" projects with their collaborators (presumably projects the caller hosts);
/// - "rejoined" projects.
///
/// Then:
/// - room participants are notified of the updated room;
/// - for each reshared project, every collaborator is told the caller's peer id changed
///   (old connection to this connection), and every collaborator other than the caller is
///   forwarded the project's worktrees;
/// - `notify_rejoined_projects` handles the rejoined projects.
///
/// Finally, if the room belongs to a channel, the channel is updated for its members, and
/// the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the value returned by `rejoin_room` (its `into_inner` suggests a guard —
    // presumably over a database transaction; verify) is released before the connection
    // pool is awaited below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id moved from the old
            // connection to the current one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
1461fn notify_rejoined_projects(
1462    rejoined_projects: &mut Vec<RejoinedProject>,
1463    session: &Session,
1464) -> Result<()> {
1465    for project in rejoined_projects.iter() {
1466        for collaborator in &project.collaborators {
1467            session
1468                .peer
1469                .send(
1470                    collaborator.connection_id,
1471                    proto::UpdateProjectCollaborator {
1472                        project_id: project.id.to_proto(),
1473                        old_peer_id: Some(project.old_connection_id.into()),
1474                        new_peer_id: Some(session.connection_id.into()),
1475                    },
1476                )
1477                .trace_err();
1478        }
1479    }
1480
1481    for project in rejoined_projects {
1482        for worktree in mem::take(&mut project.worktrees) {
1483            // Stream this worktree's entries.
1484            let message = proto::UpdateWorktree {
1485                project_id: project.id.to_proto(),
1486                worktree_id: worktree.id,
1487                abs_path: worktree.abs_path.clone(),
1488                root_name: worktree.root_name,
1489                updated_entries: worktree.updated_entries,
1490                removed_entries: worktree.removed_entries,
1491                scan_id: worktree.scan_id,
1492                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1493                updated_repositories: worktree.updated_repositories,
1494                removed_repositories: worktree.removed_repositories,
1495            };
1496            for update in proto::split_worktree_update(message) {
1497                session.peer.send(session.connection_id, update)?;
1498            }
1499
1500            // Stream this worktree's diagnostics.
1501            for summary in worktree.diagnostic_summaries {
1502                session.peer.send(
1503                    session.connection_id,
1504                    proto::UpdateDiagnosticSummary {
1505                        project_id: project.id.to_proto(),
1506                        worktree_id: worktree.id,
1507                        summary: Some(summary),
1508                    },
1509                )?;
1510            }
1511
1512            for settings_file in worktree.settings_files {
1513                session.peer.send(
1514                    session.connection_id,
1515                    proto::UpdateWorktreeSettings {
1516                        project_id: project.id.to_proto(),
1517                        worktree_id: worktree.id,
1518                        path: settings_file.path,
1519                        content: Some(settings_file.content),
1520                        kind: Some(settings_file.kind.to_proto().into()),
1521                    },
1522                )?;
1523            }
1524        }
1525
1526        for repository in mem::take(&mut project.updated_repositories) {
1527            for update in split_repository_update(repository) {
1528                session.peer.send(session.connection_id, update)?;
1529            }
1530        }
1531
1532        for id in mem::take(&mut project.removed_repositories) {
1533            session.peer.send(
1534                session.connection_id,
1535                proto::RemoveRepository {
1536                    project_id: project.id.to_proto(),
1537                    id,
1538                },
1539            )?;
1540        }
1541    }
1542
1543    Ok(())
1544}
1545
1546/// leave room disconnects from the room.
1547async fn leave_room(
1548    _: proto::LeaveRoom,
1549    response: Response<proto::LeaveRoom>,
1550    session: Session,
1551) -> Result<()> {
1552    leave_room_for_session(&session, session.connection_id).await?;
1553    response.send(proto::Ack {})?;
1554    Ok(())
1555}
1556
1557/// Updates the permissions of someone else in the room.
1558async fn set_room_participant_role(
1559    request: proto::SetRoomParticipantRole,
1560    response: Response<proto::SetRoomParticipantRole>,
1561    session: Session,
1562) -> Result<()> {
1563    let user_id = UserId::from_proto(request.user_id);
1564    let role = ChannelRole::from(request.role());
1565
1566    let (livekit_room, can_publish) = {
1567        let room = session
1568            .db()
1569            .await
1570            .set_room_participant_role(
1571                session.user_id(),
1572                RoomId::from_proto(request.room_id),
1573                user_id,
1574                role,
1575            )
1576            .await?;
1577
1578        let livekit_room = room.livekit_room.clone();
1579        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1580        room_updated(&room, &session.peer);
1581        (livekit_room, can_publish)
1582    };
1583
1584    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1585        live_kit
1586            .update_participant(
1587                livekit_room.clone(),
1588                request.user_id.to_string(),
1589                livekit_api::proto::ParticipantPermission {
1590                    can_subscribe: true,
1591                    can_publish,
1592                    can_publish_data: can_publish,
1593                    hidden: false,
1594                    recorder: false,
1595                },
1596            )
1597            .await
1598            .trace_err();
1599    }
1600
1601    response.send(proto::Ack {})?;
1602    Ok(())
1603}
1604
1605/// Call someone else into the current room
1606async fn call(
1607    request: proto::Call,
1608    response: Response<proto::Call>,
1609    session: Session,
1610) -> Result<()> {
1611    let room_id = RoomId::from_proto(request.room_id);
1612    let calling_user_id = session.user_id();
1613    let calling_connection_id = session.connection_id;
1614    let called_user_id = UserId::from_proto(request.called_user_id);
1615    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1616    if !session
1617        .db()
1618        .await
1619        .has_contact(calling_user_id, called_user_id)
1620        .await?
1621    {
1622        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1623    }
1624
1625    let incoming_call = {
1626        let (room, incoming_call) = &mut *session
1627            .db()
1628            .await
1629            .call(
1630                room_id,
1631                calling_user_id,
1632                calling_connection_id,
1633                called_user_id,
1634                initial_project_id,
1635            )
1636            .await?;
1637        room_updated(room, &session.peer);
1638        mem::take(incoming_call)
1639    };
1640    update_user_contacts(called_user_id, &session).await?;
1641
1642    let mut calls = session
1643        .connection_pool()
1644        .await
1645        .user_connection_ids(called_user_id)
1646        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1647        .collect::<FuturesUnordered<_>>();
1648
1649    while let Some(call_response) = calls.next().await {
1650        match call_response.as_ref() {
1651            Ok(_) => {
1652                response.send(proto::Ack {})?;
1653                return Ok(());
1654            }
1655            Err(_) => {
1656                call_response.trace_err();
1657            }
1658        }
1659    }
1660
1661    {
1662        let room = session
1663            .db()
1664            .await
1665            .call_failed(room_id, called_user_id)
1666            .await?;
1667        room_updated(&room, &session.peer);
1668    }
1669    update_user_contacts(called_user_id, &session).await?;
1670
1671    Err(anyhow!("failed to ring user"))?
1672}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676    request: proto::CancelCall,
1677    response: Response<proto::CancelCall>,
1678    session: Session,
1679) -> Result<()> {
1680    let called_user_id = UserId::from_proto(request.called_user_id);
1681    let room_id = RoomId::from_proto(request.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .cancel_call(room_id, session.connection_id, called_user_id)
1687            .await?;
1688        room_updated(&room, &session.peer);
1689    }
1690
1691    for connection_id in session
1692        .connection_pool()
1693        .await
1694        .user_connection_ids(called_user_id)
1695    {
1696        session
1697            .peer
1698            .send(
1699                connection_id,
1700                proto::CallCanceled {
1701                    room_id: room_id.to_proto(),
1702                },
1703            )
1704            .trace_err();
1705    }
1706    response.send(proto::Ack {})?;
1707
1708    update_user_contacts(called_user_id, &session).await?;
1709    Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1714    let room_id = RoomId::from_proto(message.room_id);
1715    {
1716        let room = session
1717            .db()
1718            .await
1719            .decline_call(Some(room_id), session.user_id())
1720            .await?
1721            .context("declining call")?;
1722        room_updated(&room, &session.peer);
1723    }
1724
1725    for connection_id in session
1726        .connection_pool()
1727        .await
1728        .user_connection_ids(session.user_id())
1729    {
1730        session
1731            .peer
1732            .send(
1733                connection_id,
1734                proto::CallCanceled {
1735                    room_id: room_id.to_proto(),
1736                },
1737            )
1738            .trace_err();
1739    }
1740    update_user_contacts(session.user_id(), &session).await?;
1741    Ok(())
1742}
1743
1744/// Updates other participants in the room with your current location.
1745async fn update_participant_location(
1746    request: proto::UpdateParticipantLocation,
1747    response: Response<proto::UpdateParticipantLocation>,
1748    session: Session,
1749) -> Result<()> {
1750    let room_id = RoomId::from_proto(request.room_id);
1751    let location = request.location.context("invalid location")?;
1752
1753    let db = session.db().await;
1754    let room = db
1755        .update_room_participant_location(room_id, session.connection_id, location)
1756        .await?;
1757
1758    room_updated(&room, &session.peer);
1759    response.send(proto::Ack {})?;
1760    Ok(())
1761}
1762
1763/// Share a project into the room.
1764async fn share_project(
1765    request: proto::ShareProject,
1766    response: Response<proto::ShareProject>,
1767    session: Session,
1768) -> Result<()> {
1769    let (project_id, room) = &*session
1770        .db()
1771        .await
1772        .share_project(
1773            RoomId::from_proto(request.room_id),
1774            session.connection_id,
1775            &request.worktrees,
1776            request.is_ssh_project,
1777        )
1778        .await?;
1779    response.send(proto::ShareProjectResponse {
1780        project_id: project_id.to_proto(),
1781    })?;
1782    room_updated(room, &session.peer);
1783
1784    Ok(())
1785}
1786
1787/// Unshare a project from the room.
1788async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1789    let project_id = ProjectId::from_proto(message.project_id);
1790    unshare_project_internal(project_id, session.connection_id, &session).await
1791}
1792
1793async fn unshare_project_internal(
1794    project_id: ProjectId,
1795    connection_id: ConnectionId,
1796    session: &Session,
1797) -> Result<()> {
1798    let delete = {
1799        let room_guard = session
1800            .db()
1801            .await
1802            .unshare_project(project_id, connection_id)
1803            .await?;
1804
1805        let (delete, room, guest_connection_ids) = &*room_guard;
1806
1807        let message = proto::UnshareProject {
1808            project_id: project_id.to_proto(),
1809        };
1810
1811        broadcast(
1812            Some(connection_id),
1813            guest_connection_ids.iter().copied(),
1814            |conn_id| session.peer.send(conn_id, message.clone()),
1815        );
1816        if let Some(room) = room {
1817            room_updated(room, &session.peer);
1818        }
1819
1820        *delete
1821    };
1822
1823    if delete {
1824        let db = session.db().await;
1825        db.delete_project(project_id).await?;
1826    }
1827
1828    Ok(())
1829}
1830
1831/// Join someone elses shared project.
1832async fn join_project(
1833    request: proto::JoinProject,
1834    response: Response<proto::JoinProject>,
1835    session: Session,
1836) -> Result<()> {
1837    let project_id = ProjectId::from_proto(request.project_id);
1838
1839    tracing::info!(%project_id, "join project");
1840
1841    let db = session.db().await;
1842    let (project, replica_id) = &mut *db
1843        .join_project(project_id, session.connection_id, session.user_id())
1844        .await?;
1845    drop(db);
1846    tracing::info!(%project_id, "join remote project");
1847    join_project_internal(response, session, project, replica_id)
1848}
1849
1850trait JoinProjectInternalResponse {
1851    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1852}
1853impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1854    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1855        Response::<proto::JoinProject>::send(self, result)
1856    }
1857}
1858
1859fn join_project_internal(
1860    response: impl JoinProjectInternalResponse,
1861    session: Session,
1862    project: &mut Project,
1863    replica_id: &ReplicaId,
1864) -> Result<()> {
1865    let collaborators = project
1866        .collaborators
1867        .iter()
1868        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1869        .map(|collaborator| collaborator.to_proto())
1870        .collect::<Vec<_>>();
1871    let project_id = project.id;
1872    let guest_user_id = session.user_id();
1873
1874    let worktrees = project
1875        .worktrees
1876        .iter()
1877        .map(|(id, worktree)| proto::WorktreeMetadata {
1878            id: *id,
1879            root_name: worktree.root_name.clone(),
1880            visible: worktree.visible,
1881            abs_path: worktree.abs_path.clone(),
1882        })
1883        .collect::<Vec<_>>();
1884
1885    let add_project_collaborator = proto::AddProjectCollaborator {
1886        project_id: project_id.to_proto(),
1887        collaborator: Some(proto::Collaborator {
1888            peer_id: Some(session.connection_id.into()),
1889            replica_id: replica_id.0 as u32,
1890            user_id: guest_user_id.to_proto(),
1891            is_host: false,
1892        }),
1893    };
1894
1895    for collaborator in &collaborators {
1896        session
1897            .peer
1898            .send(
1899                collaborator.peer_id.unwrap().into(),
1900                add_project_collaborator.clone(),
1901            )
1902            .trace_err();
1903    }
1904
1905    // First, we send the metadata associated with each worktree.
1906    response.send(proto::JoinProjectResponse {
1907        project_id: project.id.0 as u64,
1908        worktrees: worktrees.clone(),
1909        replica_id: replica_id.0 as u32,
1910        collaborators: collaborators.clone(),
1911        language_servers: project.language_servers.clone(),
1912        role: project.role.into(),
1913    })?;
1914
1915    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1916        // Stream this worktree's entries.
1917        let message = proto::UpdateWorktree {
1918            project_id: project_id.to_proto(),
1919            worktree_id,
1920            abs_path: worktree.abs_path.clone(),
1921            root_name: worktree.root_name,
1922            updated_entries: worktree.entries,
1923            removed_entries: Default::default(),
1924            scan_id: worktree.scan_id,
1925            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1926            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1927            removed_repositories: Default::default(),
1928        };
1929        for update in proto::split_worktree_update(message) {
1930            session.peer.send(session.connection_id, update.clone())?;
1931        }
1932
1933        // Stream this worktree's diagnostics.
1934        for summary in worktree.diagnostic_summaries {
1935            session.peer.send(
1936                session.connection_id,
1937                proto::UpdateDiagnosticSummary {
1938                    project_id: project_id.to_proto(),
1939                    worktree_id: worktree.id,
1940                    summary: Some(summary),
1941                },
1942            )?;
1943        }
1944
1945        for settings_file in worktree.settings_files {
1946            session.peer.send(
1947                session.connection_id,
1948                proto::UpdateWorktreeSettings {
1949                    project_id: project_id.to_proto(),
1950                    worktree_id: worktree.id,
1951                    path: settings_file.path,
1952                    content: Some(settings_file.content),
1953                    kind: Some(settings_file.kind.to_proto() as i32),
1954                },
1955            )?;
1956        }
1957    }
1958
1959    for repository in mem::take(&mut project.repositories) {
1960        for update in split_repository_update(repository) {
1961            session.peer.send(session.connection_id, update)?;
1962        }
1963    }
1964
1965    for language_server in &project.language_servers {
1966        session.peer.send(
1967            session.connection_id,
1968            proto::UpdateLanguageServer {
1969                project_id: project_id.to_proto(),
1970                language_server_id: language_server.id,
1971                variant: Some(
1972                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1973                        proto::LspDiskBasedDiagnosticsUpdated {},
1974                    ),
1975                ),
1976            },
1977        )?;
1978    }
1979
1980    Ok(())
1981}
1982
1983/// Leave someone elses shared project.
1984async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1985    let sender_id = session.connection_id;
1986    let project_id = ProjectId::from_proto(request.project_id);
1987    let db = session.db().await;
1988
1989    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1990    tracing::info!(
1991        %project_id,
1992        "leave project"
1993    );
1994
1995    project_left(project, &session);
1996    if let Some(room) = room {
1997        room_updated(room, &session.peer);
1998    }
1999
2000    Ok(())
2001}
2002
2003/// Updates other participants with changes to the project
2004async fn update_project(
2005    request: proto::UpdateProject,
2006    response: Response<proto::UpdateProject>,
2007    session: Session,
2008) -> Result<()> {
2009    let project_id = ProjectId::from_proto(request.project_id);
2010    let (room, guest_connection_ids) = &*session
2011        .db()
2012        .await
2013        .update_project(project_id, session.connection_id, &request.worktrees)
2014        .await?;
2015    broadcast(
2016        Some(session.connection_id),
2017        guest_connection_ids.iter().copied(),
2018        |connection_id| {
2019            session
2020                .peer
2021                .forward_send(session.connection_id, connection_id, request.clone())
2022        },
2023    );
2024    if let Some(room) = room {
2025        room_updated(room, &session.peer);
2026    }
2027    response.send(proto::Ack {})?;
2028
2029    Ok(())
2030}
2031
2032/// Updates other participants with changes to the worktree
2033async fn update_worktree(
2034    request: proto::UpdateWorktree,
2035    response: Response<proto::UpdateWorktree>,
2036    session: Session,
2037) -> Result<()> {
2038    let guest_connection_ids = session
2039        .db()
2040        .await
2041        .update_worktree(&request, session.connection_id)
2042        .await?;
2043
2044    broadcast(
2045        Some(session.connection_id),
2046        guest_connection_ids.iter().copied(),
2047        |connection_id| {
2048            session
2049                .peer
2050                .forward_send(session.connection_id, connection_id, request.clone())
2051        },
2052    );
2053    response.send(proto::Ack {})?;
2054    Ok(())
2055}
2056
2057async fn update_repository(
2058    request: proto::UpdateRepository,
2059    response: Response<proto::UpdateRepository>,
2060    session: Session,
2061) -> Result<()> {
2062    let guest_connection_ids = session
2063        .db()
2064        .await
2065        .update_repository(&request, session.connection_id)
2066        .await?;
2067
2068    broadcast(
2069        Some(session.connection_id),
2070        guest_connection_ids.iter().copied(),
2071        |connection_id| {
2072            session
2073                .peer
2074                .forward_send(session.connection_id, connection_id, request.clone())
2075        },
2076    );
2077    response.send(proto::Ack {})?;
2078    Ok(())
2079}
2080
2081async fn remove_repository(
2082    request: proto::RemoveRepository,
2083    response: Response<proto::RemoveRepository>,
2084    session: Session,
2085) -> Result<()> {
2086    let guest_connection_ids = session
2087        .db()
2088        .await
2089        .remove_repository(&request, session.connection_id)
2090        .await?;
2091
2092    broadcast(
2093        Some(session.connection_id),
2094        guest_connection_ids.iter().copied(),
2095        |connection_id| {
2096            session
2097                .peer
2098                .forward_send(session.connection_id, connection_id, request.clone())
2099        },
2100    );
2101    response.send(proto::Ack {})?;
2102    Ok(())
2103}
2104
2105/// Updates other participants with changes to the diagnostics
2106async fn update_diagnostic_summary(
2107    message: proto::UpdateDiagnosticSummary,
2108    session: Session,
2109) -> Result<()> {
2110    let guest_connection_ids = session
2111        .db()
2112        .await
2113        .update_diagnostic_summary(&message, session.connection_id)
2114        .await?;
2115
2116    broadcast(
2117        Some(session.connection_id),
2118        guest_connection_ids.iter().copied(),
2119        |connection_id| {
2120            session
2121                .peer
2122                .forward_send(session.connection_id, connection_id, message.clone())
2123        },
2124    );
2125
2126    Ok(())
2127}
2128
2129/// Updates other participants with changes to the worktree settings
2130async fn update_worktree_settings(
2131    message: proto::UpdateWorktreeSettings,
2132    session: Session,
2133) -> Result<()> {
2134    let guest_connection_ids = session
2135        .db()
2136        .await
2137        .update_worktree_settings(&message, session.connection_id)
2138        .await?;
2139
2140    broadcast(
2141        Some(session.connection_id),
2142        guest_connection_ids.iter().copied(),
2143        |connection_id| {
2144            session
2145                .peer
2146                .forward_send(session.connection_id, connection_id, message.clone())
2147        },
2148    );
2149
2150    Ok(())
2151}
2152
2153/// Notify other participants that a language server has started.
2154async fn start_language_server(
2155    request: proto::StartLanguageServer,
2156    session: Session,
2157) -> Result<()> {
2158    let guest_connection_ids = session
2159        .db()
2160        .await
2161        .start_language_server(&request, session.connection_id)
2162        .await?;
2163
2164    broadcast(
2165        Some(session.connection_id),
2166        guest_connection_ids.iter().copied(),
2167        |connection_id| {
2168            session
2169                .peer
2170                .forward_send(session.connection_id, connection_id, request.clone())
2171        },
2172    );
2173    Ok(())
2174}
2175
2176/// Notify other participants that a language server has changed.
2177async fn update_language_server(
2178    request: proto::UpdateLanguageServer,
2179    session: Session,
2180) -> Result<()> {
2181    let project_id = ProjectId::from_proto(request.project_id);
2182    let project_connection_ids = session
2183        .db()
2184        .await
2185        .project_connection_ids(project_id, session.connection_id, true)
2186        .await?;
2187    broadcast(
2188        Some(session.connection_id),
2189        project_connection_ids.iter().copied(),
2190        |connection_id| {
2191            session
2192                .peer
2193                .forward_send(session.connection_id, connection_id, request.clone())
2194        },
2195    );
2196    Ok(())
2197}
2198
2199/// forward a project request to the host. These requests should be read only
2200/// as guests are allowed to send them.
2201async fn forward_read_only_project_request<T>(
2202    request: T,
2203    response: Response<T>,
2204    session: Session,
2205) -> Result<()>
2206where
2207    T: EntityMessage + RequestMessage,
2208{
2209    let project_id = ProjectId::from_proto(request.remote_entity_id());
2210    let host_connection_id = session
2211        .db()
2212        .await
2213        .host_for_read_only_project_request(project_id, session.connection_id)
2214        .await?;
2215    let payload = session
2216        .peer
2217        .forward_request(session.connection_id, host_connection_id, request)
2218        .await?;
2219    response.send(payload)?;
2220    Ok(())
2221}
2222
2223async fn forward_find_search_candidates_request(
2224    request: proto::FindSearchCandidates,
2225    response: Response<proto::FindSearchCandidates>,
2226    session: Session,
2227) -> Result<()> {
2228    let project_id = ProjectId::from_proto(request.remote_entity_id());
2229    let host_connection_id = session
2230        .db()
2231        .await
2232        .host_for_read_only_project_request(project_id, session.connection_id)
2233        .await?;
2234    let payload = session
2235        .peer
2236        .forward_request(session.connection_id, host_connection_id, request)
2237        .await?;
2238    response.send(payload)?;
2239    Ok(())
2240}
2241
2242/// forward a project request to the host. These requests are disallowed
2243/// for guests.
2244async fn forward_mutating_project_request<T>(
2245    request: T,
2246    response: Response<T>,
2247    session: Session,
2248) -> Result<()>
2249where
2250    T: EntityMessage + RequestMessage,
2251{
2252    let project_id = ProjectId::from_proto(request.remote_entity_id());
2253
2254    let host_connection_id = session
2255        .db()
2256        .await
2257        .host_for_mutating_project_request(project_id, session.connection_id)
2258        .await?;
2259    let payload = session
2260        .peer
2261        .forward_request(session.connection_id, host_connection_id, request)
2262        .await?;
2263    response.send(payload)?;
2264    Ok(())
2265}
2266
2267/// Notify other participants that a new buffer has been created
2268async fn create_buffer_for_peer(
2269    request: proto::CreateBufferForPeer,
2270    session: Session,
2271) -> Result<()> {
2272    session
2273        .db()
2274        .await
2275        .check_user_is_project_host(
2276            ProjectId::from_proto(request.project_id),
2277            session.connection_id,
2278        )
2279        .await?;
2280    let peer_id = request.peer_id.context("invalid peer id")?;
2281    session
2282        .peer
2283        .forward_send(session.connection_id, peer_id.into(), request)?;
2284    Ok(())
2285}
2286
2287/// Notify other participants that a buffer has been updated. This is
2288/// allowed for guests as long as the update is limited to selections.
2289async fn update_buffer(
2290    request: proto::UpdateBuffer,
2291    response: Response<proto::UpdateBuffer>,
2292    session: Session,
2293) -> Result<()> {
2294    let project_id = ProjectId::from_proto(request.project_id);
2295    let mut capability = Capability::ReadOnly;
2296
2297    for op in request.operations.iter() {
2298        match op.variant {
2299            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2300            Some(_) => capability = Capability::ReadWrite,
2301        }
2302    }
2303
2304    let host = {
2305        let guard = session
2306            .db()
2307            .await
2308            .connections_for_buffer_update(project_id, session.connection_id, capability)
2309            .await?;
2310
2311        let (host, guests) = &*guard;
2312
2313        broadcast(
2314            Some(session.connection_id),
2315            guests.clone(),
2316            |connection_id| {
2317                session
2318                    .peer
2319                    .forward_send(session.connection_id, connection_id, request.clone())
2320            },
2321        );
2322
2323        *host
2324    };
2325
2326    if host != session.connection_id {
2327        session
2328            .peer
2329            .forward_request(session.connection_id, host, request.clone())
2330            .await?;
2331    }
2332
2333    response.send(proto::Ack {})?;
2334    Ok(())
2335}
2336
2337async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2338    let project_id = ProjectId::from_proto(message.project_id);
2339
2340    let operation = message.operation.as_ref().context("invalid operation")?;
2341    let capability = match operation.variant.as_ref() {
2342        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2343            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2344                match buffer_op.variant {
2345                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2346                        Capability::ReadOnly
2347                    }
2348                    _ => Capability::ReadWrite,
2349                }
2350            } else {
2351                Capability::ReadWrite
2352            }
2353        }
2354        Some(_) => Capability::ReadWrite,
2355        None => Capability::ReadOnly,
2356    };
2357
2358    let guard = session
2359        .db()
2360        .await
2361        .connections_for_buffer_update(project_id, session.connection_id, capability)
2362        .await?;
2363
2364    let (host, guests) = &*guard;
2365
2366    broadcast(
2367        Some(session.connection_id),
2368        guests.iter().chain([host]).copied(),
2369        |connection_id| {
2370            session
2371                .peer
2372                .forward_send(session.connection_id, connection_id, message.clone())
2373        },
2374    );
2375
2376    Ok(())
2377}
2378
2379/// Notify other participants that a project has been updated.
2380async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2381    request: T,
2382    session: Session,
2383) -> Result<()> {
2384    let project_id = ProjectId::from_proto(request.remote_entity_id());
2385    let project_connection_ids = session
2386        .db()
2387        .await
2388        .project_connection_ids(project_id, session.connection_id, false)
2389        .await?;
2390
2391    broadcast(
2392        Some(session.connection_id),
2393        project_connection_ids.iter().copied(),
2394        |connection_id| {
2395            session
2396                .peer
2397                .forward_send(session.connection_id, connection_id, request.clone())
2398        },
2399    );
2400    Ok(())
2401}
2402
2403/// Start following another user in a call.
2404async fn follow(
2405    request: proto::Follow,
2406    response: Response<proto::Follow>,
2407    session: Session,
2408) -> Result<()> {
2409    let room_id = RoomId::from_proto(request.room_id);
2410    let project_id = request.project_id.map(ProjectId::from_proto);
2411    let leader_id = request.leader_id.context("invalid leader id")?.into();
2412    let follower_id = session.connection_id;
2413
2414    session
2415        .db()
2416        .await
2417        .check_room_participants(room_id, leader_id, session.connection_id)
2418        .await?;
2419
2420    let response_payload = session
2421        .peer
2422        .forward_request(session.connection_id, leader_id, request)
2423        .await?;
2424    response.send(response_payload)?;
2425
2426    if let Some(project_id) = project_id {
2427        let room = session
2428            .db()
2429            .await
2430            .follow(room_id, project_id, leader_id, follower_id)
2431            .await?;
2432        room_updated(&room, &session.peer);
2433    }
2434
2435    Ok(())
2436}
2437
2438/// Stop following another user in a call.
2439async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2440    let room_id = RoomId::from_proto(request.room_id);
2441    let project_id = request.project_id.map(ProjectId::from_proto);
2442    let leader_id = request.leader_id.context("invalid leader id")?.into();
2443    let follower_id = session.connection_id;
2444
2445    session
2446        .db()
2447        .await
2448        .check_room_participants(room_id, leader_id, session.connection_id)
2449        .await?;
2450
2451    session
2452        .peer
2453        .forward_send(session.connection_id, leader_id, request)?;
2454
2455    if let Some(project_id) = project_id {
2456        let room = session
2457            .db()
2458            .await
2459            .unfollow(room_id, project_id, leader_id, follower_id)
2460            .await?;
2461        room_updated(&room, &session.peer);
2462    }
2463
2464    Ok(())
2465}
2466
2467/// Notify everyone following you of your current location.
2468async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2469    let room_id = RoomId::from_proto(request.room_id);
2470    let database = session.db.lock().await;
2471
2472    let connection_ids = if let Some(project_id) = request.project_id {
2473        let project_id = ProjectId::from_proto(project_id);
2474        database
2475            .project_connection_ids(project_id, session.connection_id, true)
2476            .await?
2477    } else {
2478        database
2479            .room_connection_ids(room_id, session.connection_id)
2480            .await?
2481    };
2482
2483    // For now, don't send view update messages back to that view's current leader.
2484    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2485        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2486        _ => None,
2487    });
2488
2489    for connection_id in connection_ids.iter().cloned() {
2490        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2491            session
2492                .peer
2493                .forward_send(session.connection_id, connection_id, request.clone())?;
2494        }
2495    }
2496    Ok(())
2497}
2498
2499/// Get public data about users.
2500async fn get_users(
2501    request: proto::GetUsers,
2502    response: Response<proto::GetUsers>,
2503    session: Session,
2504) -> Result<()> {
2505    let user_ids = request
2506        .user_ids
2507        .into_iter()
2508        .map(UserId::from_proto)
2509        .collect();
2510    let users = session
2511        .db()
2512        .await
2513        .get_users_by_ids(user_ids)
2514        .await?
2515        .into_iter()
2516        .map(|user| proto::User {
2517            id: user.id.to_proto(),
2518            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2519            github_login: user.github_login,
2520            email: user.email_address,
2521            name: user.name,
2522        })
2523        .collect();
2524    response.send(proto::UsersResponse { users })?;
2525    Ok(())
2526}
2527
2528/// Search for users (to invite) buy Github login
2529async fn fuzzy_search_users(
2530    request: proto::FuzzySearchUsers,
2531    response: Response<proto::FuzzySearchUsers>,
2532    session: Session,
2533) -> Result<()> {
2534    let query = request.query;
2535    let users = match query.len() {
2536        0 => vec![],
2537        1 | 2 => session
2538            .db()
2539            .await
2540            .get_user_by_github_login(&query)
2541            .await?
2542            .into_iter()
2543            .collect(),
2544        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2545    };
2546    let users = users
2547        .into_iter()
2548        .filter(|user| user.id != session.user_id())
2549        .map(|user| proto::User {
2550            id: user.id.to_proto(),
2551            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2552            github_login: user.github_login,
2553            name: user.name,
2554            email: user.email_address,
2555        })
2556        .collect();
2557    response.send(proto::UsersResponse { users })?;
2558    Ok(())
2559}
2560
2561/// Send a contact request to another user.
2562async fn request_contact(
2563    request: proto::RequestContact,
2564    response: Response<proto::RequestContact>,
2565    session: Session,
2566) -> Result<()> {
2567    let requester_id = session.user_id();
2568    let responder_id = UserId::from_proto(request.responder_id);
2569    if requester_id == responder_id {
2570        return Err(anyhow!("cannot add yourself as a contact"))?;
2571    }
2572
2573    let notifications = session
2574        .db()
2575        .await
2576        .send_contact_request(requester_id, responder_id)
2577        .await?;
2578
2579    // Update outgoing contact requests of requester
2580    let mut update = proto::UpdateContacts::default();
2581    update.outgoing_requests.push(responder_id.to_proto());
2582    for connection_id in session
2583        .connection_pool()
2584        .await
2585        .user_connection_ids(requester_id)
2586    {
2587        session.peer.send(connection_id, update.clone())?;
2588    }
2589
2590    // Update incoming contact requests of responder
2591    let mut update = proto::UpdateContacts::default();
2592    update
2593        .incoming_requests
2594        .push(proto::IncomingContactRequest {
2595            requester_id: requester_id.to_proto(),
2596        });
2597    let connection_pool = session.connection_pool().await;
2598    for connection_id in connection_pool.user_connection_ids(responder_id) {
2599        session.peer.send(connection_id, update.clone())?;
2600    }
2601
2602    send_notifications(&connection_pool, &session.peer, notifications);
2603
2604    response.send(proto::Ack {})?;
2605    Ok(())
2606}
2607
2608/// Accept or decline a contact request
2609async fn respond_to_contact_request(
2610    request: proto::RespondToContactRequest,
2611    response: Response<proto::RespondToContactRequest>,
2612    session: Session,
2613) -> Result<()> {
2614    let responder_id = session.user_id();
2615    let requester_id = UserId::from_proto(request.requester_id);
2616    let db = session.db().await;
2617    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2618        db.dismiss_contact_notification(responder_id, requester_id)
2619            .await?;
2620    } else {
2621        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2622
2623        let notifications = db
2624            .respond_to_contact_request(responder_id, requester_id, accept)
2625            .await?;
2626        let requester_busy = db.is_user_busy(requester_id).await?;
2627        let responder_busy = db.is_user_busy(responder_id).await?;
2628
2629        let pool = session.connection_pool().await;
2630        // Update responder with new contact
2631        let mut update = proto::UpdateContacts::default();
2632        if accept {
2633            update
2634                .contacts
2635                .push(contact_for_user(requester_id, requester_busy, &pool));
2636        }
2637        update
2638            .remove_incoming_requests
2639            .push(requester_id.to_proto());
2640        for connection_id in pool.user_connection_ids(responder_id) {
2641            session.peer.send(connection_id, update.clone())?;
2642        }
2643
2644        // Update requester with new contact
2645        let mut update = proto::UpdateContacts::default();
2646        if accept {
2647            update
2648                .contacts
2649                .push(contact_for_user(responder_id, responder_busy, &pool));
2650        }
2651        update
2652            .remove_outgoing_requests
2653            .push(responder_id.to_proto());
2654
2655        for connection_id in pool.user_connection_ids(requester_id) {
2656            session.peer.send(connection_id, update.clone())?;
2657        }
2658
2659        send_notifications(&pool, &session.peer, notifications);
2660    }
2661
2662    response.send(proto::Ack {})?;
2663    Ok(())
2664}
2665
2666/// Remove a contact.
2667async fn remove_contact(
2668    request: proto::RemoveContact,
2669    response: Response<proto::RemoveContact>,
2670    session: Session,
2671) -> Result<()> {
2672    let requester_id = session.user_id();
2673    let responder_id = UserId::from_proto(request.user_id);
2674    let db = session.db().await;
2675    let (contact_accepted, deleted_notification_id) =
2676        db.remove_contact(requester_id, responder_id).await?;
2677
2678    let pool = session.connection_pool().await;
2679    // Update outgoing contact requests of requester
2680    let mut update = proto::UpdateContacts::default();
2681    if contact_accepted {
2682        update.remove_contacts.push(responder_id.to_proto());
2683    } else {
2684        update
2685            .remove_outgoing_requests
2686            .push(responder_id.to_proto());
2687    }
2688    for connection_id in pool.user_connection_ids(requester_id) {
2689        session.peer.send(connection_id, update.clone())?;
2690    }
2691
2692    // Update incoming contact requests of responder
2693    let mut update = proto::UpdateContacts::default();
2694    if contact_accepted {
2695        update.remove_contacts.push(requester_id.to_proto());
2696    } else {
2697        update
2698            .remove_incoming_requests
2699            .push(requester_id.to_proto());
2700    }
2701    for connection_id in pool.user_connection_ids(responder_id) {
2702        session.peer.send(connection_id, update.clone())?;
2703        if let Some(notification_id) = deleted_notification_id {
2704            session.peer.send(
2705                connection_id,
2706                proto::DeleteNotification {
2707                    notification_id: notification_id.to_proto(),
2708                },
2709            )?;
2710        }
2711    }
2712
2713    response.send(proto::Ack {})?;
2714    Ok(())
2715}
2716
2717fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2718    version.0.minor() < 139
2719}
2720
2721async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2722    if is_staff {
2723        return Ok(proto::Plan::ZedPro);
2724    }
2725
2726    let subscription = db.get_active_billing_subscription(user_id).await?;
2727    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2728
2729    let plan = if let Some(subscription_kind) = subscription_kind {
2730        match subscription_kind {
2731            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2732            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2733            SubscriptionKind::ZedFree => proto::Plan::Free,
2734        }
2735    } else {
2736        proto::Plan::Free
2737    };
2738
2739    Ok(plan)
2740}
2741
/// Build the `UpdateUserPlan` message describing `user`'s plan, trial, billing, and usage state.
///
/// - `plan` comes from `current_plan`, so staff are always reported as Zed Pro.
/// - Staff are always reported with usage-based billing enabled. Other users report their
///   stored `model_request_overages_enabled` preference, or `None` if they have none.
/// - The subscription period and usage are only computed when `llm_db` is provided. Without it
///   both fields are `None`.
/// - `account_too_young` is set for non-Zed-Pro users whose account is younger than
///   `MIN_ACCOUNT_AGE_FOR_LLM_USE`.
///
/// # Errors
/// Propagates any database error from the lookups performed here.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already fetched the active subscription; this is a
        // second lookup of the same row.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only looked up when there is a current billing period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // NOTE(review): `as u64` wraps for pre-epoch (negative) timestamps — confirm this
        // cannot happen for stored billing data.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Translate the proto plan into the LLM client's plan to look up its limits.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Trial users with the extended-trial feature flag get a model request limit of
            // 1_000 in place of the plan's default limit.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below truncate values outside the u32 range —
            // presumably usage counts and limits always fit; confirm.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2850
2851async fn update_user_plan(session: &Session) -> Result<()> {
2852    let db = session.db().await;
2853
2854    let update_user_plan = make_update_user_plan_message(
2855        session.principal.user(),
2856        session.is_staff(),
2857        &db.0,
2858        session.app_state.llm_db.clone(),
2859    )
2860    .await?;
2861
2862    session
2863        .peer
2864        .send(session.connection_id, update_user_plan)
2865        .trace_err();
2866
2867    Ok(())
2868}
2869
2870async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2871    subscribe_user_to_channels(session.user_id(), &session).await?;
2872    Ok(())
2873}
2874
2875async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2876    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2877    let mut pool = session.connection_pool().await;
2878    for membership in &channels_for_user.channel_memberships {
2879        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2880    }
2881    session.peer.send(
2882        session.connection_id,
2883        build_update_user_channels(&channels_for_user),
2884    )?;
2885    session.peer.send(
2886        session.connection_id,
2887        build_channels_update(channels_for_user),
2888    )?;
2889    Ok(())
2890}
2891
2892/// Creates a new channel.
2893async fn create_channel(
2894    request: proto::CreateChannel,
2895    response: Response<proto::CreateChannel>,
2896    session: Session,
2897) -> Result<()> {
2898    let db = session.db().await;
2899
2900    let parent_id = request.parent_id.map(ChannelId::from_proto);
2901    let (channel, membership) = db
2902        .create_channel(&request.name, parent_id, session.user_id())
2903        .await?;
2904
2905    let root_id = channel.root_id();
2906    let channel = Channel::from_model(channel);
2907
2908    response.send(proto::CreateChannelResponse {
2909        channel: Some(channel.to_proto()),
2910        parent_id: request.parent_id,
2911    })?;
2912
2913    let mut connection_pool = session.connection_pool().await;
2914    if let Some(membership) = membership {
2915        connection_pool.subscribe_to_channel(
2916            membership.user_id,
2917            membership.channel_id,
2918            membership.role,
2919        );
2920        let update = proto::UpdateUserChannels {
2921            channel_memberships: vec![proto::ChannelMembership {
2922                channel_id: membership.channel_id.to_proto(),
2923                role: membership.role.into(),
2924            }],
2925            ..Default::default()
2926        };
2927        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2928            session.peer.send(connection_id, update.clone())?;
2929        }
2930    }
2931
2932    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2933        if !role.can_see_channel(channel.visibility) {
2934            continue;
2935        }
2936
2937        let update = proto::UpdateChannels {
2938            channels: vec![channel.to_proto()],
2939            ..Default::default()
2940        };
2941        session.peer.send(connection_id, update.clone())?;
2942    }
2943
2944    Ok(())
2945}
2946
2947/// Delete a channel
2948async fn delete_channel(
2949    request: proto::DeleteChannel,
2950    response: Response<proto::DeleteChannel>,
2951    session: Session,
2952) -> Result<()> {
2953    let db = session.db().await;
2954
2955    let channel_id = request.channel_id;
2956    let (root_channel, removed_channels) = db
2957        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2958        .await?;
2959    response.send(proto::Ack {})?;
2960
2961    // Notify members of removed channels
2962    let mut update = proto::UpdateChannels::default();
2963    update
2964        .delete_channels
2965        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2966
2967    let connection_pool = session.connection_pool().await;
2968    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2969        session.peer.send(connection_id, update.clone())?;
2970    }
2971
2972    Ok(())
2973}
2974
2975/// Invite someone to join a channel.
2976async fn invite_channel_member(
2977    request: proto::InviteChannelMember,
2978    response: Response<proto::InviteChannelMember>,
2979    session: Session,
2980) -> Result<()> {
2981    let db = session.db().await;
2982    let channel_id = ChannelId::from_proto(request.channel_id);
2983    let invitee_id = UserId::from_proto(request.user_id);
2984    let InviteMemberResult {
2985        channel,
2986        notifications,
2987    } = db
2988        .invite_channel_member(
2989            channel_id,
2990            invitee_id,
2991            session.user_id(),
2992            request.role().into(),
2993        )
2994        .await?;
2995
2996    let update = proto::UpdateChannels {
2997        channel_invitations: vec![channel.to_proto()],
2998        ..Default::default()
2999    };
3000
3001    let connection_pool = session.connection_pool().await;
3002    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3003        session.peer.send(connection_id, update.clone())?;
3004    }
3005
3006    send_notifications(&connection_pool, &session.peer, notifications);
3007
3008    response.send(proto::Ack {})?;
3009    Ok(())
3010}
3011
3012/// remove someone from a channel
3013async fn remove_channel_member(
3014    request: proto::RemoveChannelMember,
3015    response: Response<proto::RemoveChannelMember>,
3016    session: Session,
3017) -> Result<()> {
3018    let db = session.db().await;
3019    let channel_id = ChannelId::from_proto(request.channel_id);
3020    let member_id = UserId::from_proto(request.user_id);
3021
3022    let RemoveChannelMemberResult {
3023        membership_update,
3024        notification_id,
3025    } = db
3026        .remove_channel_member(channel_id, member_id, session.user_id())
3027        .await?;
3028
3029    let mut connection_pool = session.connection_pool().await;
3030    notify_membership_updated(
3031        &mut connection_pool,
3032        membership_update,
3033        member_id,
3034        &session.peer,
3035    );
3036    for connection_id in connection_pool.user_connection_ids(member_id) {
3037        if let Some(notification_id) = notification_id {
3038            session
3039                .peer
3040                .send(
3041                    connection_id,
3042                    proto::DeleteNotification {
3043                        notification_id: notification_id.to_proto(),
3044                    },
3045                )
3046                .trace_err();
3047        }
3048    }
3049
3050    response.send(proto::Ack {})?;
3051    Ok(())
3052}
3053
3054/// Toggle the channel between public and private.
3055/// Care is taken to maintain the invariant that public channels only descend from public channels,
3056/// (though members-only channels can appear at any point in the hierarchy).
3057async fn set_channel_visibility(
3058    request: proto::SetChannelVisibility,
3059    response: Response<proto::SetChannelVisibility>,
3060    session: Session,
3061) -> Result<()> {
3062    let db = session.db().await;
3063    let channel_id = ChannelId::from_proto(request.channel_id);
3064    let visibility = request.visibility().into();
3065
3066    let channel_model = db
3067        .set_channel_visibility(channel_id, visibility, session.user_id())
3068        .await?;
3069    let root_id = channel_model.root_id();
3070    let channel = Channel::from_model(channel_model);
3071
3072    let mut connection_pool = session.connection_pool().await;
3073    for (user_id, role) in connection_pool
3074        .channel_user_ids(root_id)
3075        .collect::<Vec<_>>()
3076        .into_iter()
3077    {
3078        let update = if role.can_see_channel(channel.visibility) {
3079            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3080            proto::UpdateChannels {
3081                channels: vec![channel.to_proto()],
3082                ..Default::default()
3083            }
3084        } else {
3085            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3086            proto::UpdateChannels {
3087                delete_channels: vec![channel.id.to_proto()],
3088                ..Default::default()
3089            }
3090        };
3091
3092        for connection_id in connection_pool.user_connection_ids(user_id) {
3093            session.peer.send(connection_id, update.clone())?;
3094        }
3095    }
3096
3097    response.send(proto::Ack {})?;
3098    Ok(())
3099}
3100
3101/// Alter the role for a user in the channel.
3102async fn set_channel_member_role(
3103    request: proto::SetChannelMemberRole,
3104    response: Response<proto::SetChannelMemberRole>,
3105    session: Session,
3106) -> Result<()> {
3107    let db = session.db().await;
3108    let channel_id = ChannelId::from_proto(request.channel_id);
3109    let member_id = UserId::from_proto(request.user_id);
3110    let result = db
3111        .set_channel_member_role(
3112            channel_id,
3113            session.user_id(),
3114            member_id,
3115            request.role().into(),
3116        )
3117        .await?;
3118
3119    match result {
3120        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3121            let mut connection_pool = session.connection_pool().await;
3122            notify_membership_updated(
3123                &mut connection_pool,
3124                membership_update,
3125                member_id,
3126                &session.peer,
3127            )
3128        }
3129        db::SetMemberRoleResult::InviteUpdated(channel) => {
3130            let update = proto::UpdateChannels {
3131                channel_invitations: vec![channel.to_proto()],
3132                ..Default::default()
3133            };
3134
3135            for connection_id in session
3136                .connection_pool()
3137                .await
3138                .user_connection_ids(member_id)
3139            {
3140                session.peer.send(connection_id, update.clone())?;
3141            }
3142        }
3143    }
3144
3145    response.send(proto::Ack {})?;
3146    Ok(())
3147}
3148
3149/// Change the name of a channel
3150async fn rename_channel(
3151    request: proto::RenameChannel,
3152    response: Response<proto::RenameChannel>,
3153    session: Session,
3154) -> Result<()> {
3155    let db = session.db().await;
3156    let channel_id = ChannelId::from_proto(request.channel_id);
3157    let channel_model = db
3158        .rename_channel(channel_id, session.user_id(), &request.name)
3159        .await?;
3160    let root_id = channel_model.root_id();
3161    let channel = Channel::from_model(channel_model);
3162
3163    response.send(proto::RenameChannelResponse {
3164        channel: Some(channel.to_proto()),
3165    })?;
3166
3167    let connection_pool = session.connection_pool().await;
3168    let update = proto::UpdateChannels {
3169        channels: vec![channel.to_proto()],
3170        ..Default::default()
3171    };
3172    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3173        if role.can_see_channel(channel.visibility) {
3174            session.peer.send(connection_id, update.clone())?;
3175        }
3176    }
3177
3178    Ok(())
3179}
3180
3181/// Move a channel to a new parent.
3182async fn move_channel(
3183    request: proto::MoveChannel,
3184    response: Response<proto::MoveChannel>,
3185    session: Session,
3186) -> Result<()> {
3187    let channel_id = ChannelId::from_proto(request.channel_id);
3188    let to = ChannelId::from_proto(request.to);
3189
3190    let (root_id, channels) = session
3191        .db()
3192        .await
3193        .move_channel(channel_id, to, session.user_id())
3194        .await?;
3195
3196    let connection_pool = session.connection_pool().await;
3197    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3198        let channels = channels
3199            .iter()
3200            .filter_map(|channel| {
3201                if role.can_see_channel(channel.visibility) {
3202                    Some(channel.to_proto())
3203                } else {
3204                    None
3205                }
3206            })
3207            .collect::<Vec<_>>();
3208        if channels.is_empty() {
3209            continue;
3210        }
3211
3212        let update = proto::UpdateChannels {
3213            channels,
3214            ..Default::default()
3215        };
3216
3217        session.peer.send(connection_id, update.clone())?;
3218    }
3219
3220    response.send(Ack {})?;
3221    Ok(())
3222}
3223
3224async fn reorder_channel(
3225    request: proto::ReorderChannel,
3226    response: Response<proto::ReorderChannel>,
3227    session: Session,
3228) -> Result<()> {
3229    let channel_id = ChannelId::from_proto(request.channel_id);
3230    let direction = request.direction();
3231
3232    let updated_channels = session
3233        .db()
3234        .await
3235        .reorder_channel(channel_id, direction, session.user_id())
3236        .await?;
3237
3238    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3239        let connection_pool = session.connection_pool().await;
3240        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3241            let channels = updated_channels
3242                .iter()
3243                .filter_map(|channel| {
3244                    if role.can_see_channel(channel.visibility) {
3245                        Some(channel.to_proto())
3246                    } else {
3247                        None
3248                    }
3249                })
3250                .collect::<Vec<_>>();
3251
3252            if channels.is_empty() {
3253                continue;
3254            }
3255
3256            let update = proto::UpdateChannels {
3257                channels,
3258                ..Default::default()
3259            };
3260
3261            session.peer.send(connection_id, update.clone())?;
3262        }
3263    }
3264
3265    response.send(Ack {})?;
3266    Ok(())
3267}
3268
3269/// Get the list of channel members
3270async fn get_channel_members(
3271    request: proto::GetChannelMembers,
3272    response: Response<proto::GetChannelMembers>,
3273    session: Session,
3274) -> Result<()> {
3275    let db = session.db().await;
3276    let channel_id = ChannelId::from_proto(request.channel_id);
3277    let limit = if request.limit == 0 {
3278        u16::MAX as u64
3279    } else {
3280        request.limit
3281    };
3282    let (members, users) = db
3283        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3284        .await?;
3285    response.send(proto::GetChannelMembersResponse { members, users })?;
3286    Ok(())
3287}
3288
3289/// Accept or decline a channel invitation.
3290async fn respond_to_channel_invite(
3291    request: proto::RespondToChannelInvite,
3292    response: Response<proto::RespondToChannelInvite>,
3293    session: Session,
3294) -> Result<()> {
3295    let db = session.db().await;
3296    let channel_id = ChannelId::from_proto(request.channel_id);
3297    let RespondToChannelInvite {
3298        membership_update,
3299        notifications,
3300    } = db
3301        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3302        .await?;
3303
3304    let mut connection_pool = session.connection_pool().await;
3305    if let Some(membership_update) = membership_update {
3306        notify_membership_updated(
3307            &mut connection_pool,
3308            membership_update,
3309            session.user_id(),
3310            &session.peer,
3311        );
3312    } else {
3313        let update = proto::UpdateChannels {
3314            remove_channel_invitations: vec![channel_id.to_proto()],
3315            ..Default::default()
3316        };
3317
3318        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3319            session.peer.send(connection_id, update.clone())?;
3320        }
3321    };
3322
3323    send_notifications(&connection_pool, &session.peer, notifications);
3324
3325    response.send(proto::Ack {})?;
3326
3327    Ok(())
3328}
3329
3330/// Join the channels' room
3331async fn join_channel(
3332    request: proto::JoinChannel,
3333    response: Response<proto::JoinChannel>,
3334    session: Session,
3335) -> Result<()> {
3336    let channel_id = ChannelId::from_proto(request.channel_id);
3337    join_channel_internal(channel_id, Box::new(response), session).await
3338}
3339
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to more than one request type; the
/// implementations below cover `JoinChannel` and `JoinRoom` responses.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3353
/// Join the room attached to `channel_id` on behalf of `session`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database. When a LiveKit client is
///    configured, mint a LiveKit token: guests get a guest token and
///    `can_publish: false`, and everyone else gets a room token and
///    `can_publish: true`.
/// 3. Reply to the requester, apply any membership change to the connection
///    pool, and broadcast the room update.
/// 4. After the database and connection-pool handles from the inner block are
///    released, broadcast the channel update and refresh the user's contacts.
///
/// `response` is generic over `JoinChannelInternalResponse`, so this can reply
/// to any request type that implements it.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // takes the database handle itself — confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails
        // (the failure is logged by `trace_err`); the join still succeeds.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };
    // The `db` and `connection_pool` handles from the block above are dropped here.

    // A channel join is expected to carry its channel; error out if it is missing.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3447
3448/// Start editing the channel notes
3449async fn join_channel_buffer(
3450    request: proto::JoinChannelBuffer,
3451    response: Response<proto::JoinChannelBuffer>,
3452    session: Session,
3453) -> Result<()> {
3454    let db = session.db().await;
3455    let channel_id = ChannelId::from_proto(request.channel_id);
3456
3457    let open_response = db
3458        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3459        .await?;
3460
3461    let collaborators = open_response.collaborators.clone();
3462    response.send(open_response)?;
3463
3464    let update = UpdateChannelBufferCollaborators {
3465        channel_id: channel_id.to_proto(),
3466        collaborators: collaborators.clone(),
3467    };
3468    channel_buffer_updated(
3469        session.connection_id,
3470        collaborators
3471            .iter()
3472            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3473        &update,
3474        &session.peer,
3475    );
3476
3477    Ok(())
3478}
3479
3480/// Edit the channel notes
3481async fn update_channel_buffer(
3482    request: proto::UpdateChannelBuffer,
3483    session: Session,
3484) -> Result<()> {
3485    let db = session.db().await;
3486    let channel_id = ChannelId::from_proto(request.channel_id);
3487
3488    let (collaborators, epoch, version) = db
3489        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3490        .await?;
3491
3492    channel_buffer_updated(
3493        session.connection_id,
3494        collaborators.clone(),
3495        &proto::UpdateChannelBuffer {
3496            channel_id: channel_id.to_proto(),
3497            operations: request.operations,
3498        },
3499        &session.peer,
3500    );
3501
3502    let pool = &*session.connection_pool().await;
3503
3504    let non_collaborators =
3505        pool.channel_connection_ids(channel_id)
3506            .filter_map(|(connection_id, _)| {
3507                if collaborators.contains(&connection_id) {
3508                    None
3509                } else {
3510                    Some(connection_id)
3511                }
3512            });
3513
3514    broadcast(None, non_collaborators, |peer_id| {
3515        session.peer.send(
3516            peer_id,
3517            proto::UpdateChannels {
3518                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3519                    channel_id: channel_id.to_proto(),
3520                    epoch: epoch as u64,
3521                    version: version.clone(),
3522                }],
3523                ..Default::default()
3524            },
3525        )
3526    });
3527
3528    Ok(())
3529}
3530
3531/// Rejoin the channel notes after a connection blip
3532async fn rejoin_channel_buffers(
3533    request: proto::RejoinChannelBuffers,
3534    response: Response<proto::RejoinChannelBuffers>,
3535    session: Session,
3536) -> Result<()> {
3537    let db = session.db().await;
3538    let buffers = db
3539        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3540        .await?;
3541
3542    for rejoined_buffer in &buffers {
3543        let collaborators_to_notify = rejoined_buffer
3544            .buffer
3545            .collaborators
3546            .iter()
3547            .filter_map(|c| Some(c.peer_id?.into()));
3548        channel_buffer_updated(
3549            session.connection_id,
3550            collaborators_to_notify,
3551            &proto::UpdateChannelBufferCollaborators {
3552                channel_id: rejoined_buffer.buffer.channel_id,
3553                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3554            },
3555            &session.peer,
3556        );
3557    }
3558
3559    response.send(proto::RejoinChannelBuffersResponse {
3560        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3561    })?;
3562
3563    Ok(())
3564}
3565
3566/// Stop editing the channel notes
3567async fn leave_channel_buffer(
3568    request: proto::LeaveChannelBuffer,
3569    response: Response<proto::LeaveChannelBuffer>,
3570    session: Session,
3571) -> Result<()> {
3572    let db = session.db().await;
3573    let channel_id = ChannelId::from_proto(request.channel_id);
3574
3575    let left_buffer = db
3576        .leave_channel_buffer(channel_id, session.connection_id)
3577        .await?;
3578
3579    response.send(Ack {})?;
3580
3581    channel_buffer_updated(
3582        session.connection_id,
3583        left_buffer.connections,
3584        &proto::UpdateChannelBufferCollaborators {
3585            channel_id: channel_id.to_proto(),
3586            collaborators: left_buffer.collaborators,
3587        },
3588        &session.peer,
3589    );
3590
3591    Ok(())
3592}
3593
3594fn channel_buffer_updated<T: EnvelopedMessage>(
3595    sender_id: ConnectionId,
3596    collaborators: impl IntoIterator<Item = ConnectionId>,
3597    message: &T,
3598    peer: &Peer,
3599) {
3600    broadcast(Some(sender_id), collaborators, |peer_id| {
3601        peer.send(peer_id, message.clone())
3602    });
3603}
3604
3605fn send_notifications(
3606    connection_pool: &ConnectionPool,
3607    peer: &Peer,
3608    notifications: db::NotificationBatch,
3609) {
3610    for (user_id, notification) in notifications {
3611        for connection_id in connection_pool.user_connection_ids(user_id) {
3612            if let Err(error) = peer.send(
3613                connection_id,
3614                proto::AddNotification {
3615                    notification: Some(notification.clone()),
3616                },
3617            ) {
3618                tracing::error!(
3619                    "failed to send notification to {:?} {}",
3620                    connection_id,
3621                    error
3622                );
3623            }
3624        }
3625    }
3626}
3627
/// Send a message to the channel.
///
/// The body is trimmed, then rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty. A nonce is required. After the message is stored:
/// - the chat participants returned by the database receive
///   `ChannelMessageSent` (with this connection passed to `broadcast` as the sender),
/// - the requester receives the stored message,
/// - every other connection subscribed to the channel receives an
///   `UpdateChannels` carrying the latest message id, and
/// - any notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): if `request.mentions` hold offsets into the untrimmed body,
    // trimming leading whitespace leaves them misaligned with the stored `body`
    // — confirm the mention format before relying on those offsets.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to the chat participants.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only learn the
    // new latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3720
3721/// Delete a channel message
3722async fn remove_channel_message(
3723    request: proto::RemoveChannelMessage,
3724    response: Response<proto::RemoveChannelMessage>,
3725    session: Session,
3726) -> Result<()> {
3727    let channel_id = ChannelId::from_proto(request.channel_id);
3728    let message_id = MessageId::from_proto(request.message_id);
3729    let (connection_ids, existing_notification_ids) = session
3730        .db()
3731        .await
3732        .remove_channel_message(channel_id, message_id, session.user_id())
3733        .await?;
3734
3735    broadcast(
3736        Some(session.connection_id),
3737        connection_ids,
3738        move |connection| {
3739            session.peer.send(connection, request.clone())?;
3740
3741            for notification_id in &existing_notification_ids {
3742                session.peer.send(
3743                    connection,
3744                    proto::DeleteNotification {
3745                        notification_id: (*notification_id).to_proto(),
3746                    },
3747                )?;
3748            }
3749
3750            Ok(())
3751        },
3752    );
3753    response.send(proto::Ack {})?;
3754    Ok(())
3755}
3756
3757async fn update_channel_message(
3758    request: proto::UpdateChannelMessage,
3759    response: Response<proto::UpdateChannelMessage>,
3760    session: Session,
3761) -> Result<()> {
3762    let channel_id = ChannelId::from_proto(request.channel_id);
3763    let message_id = MessageId::from_proto(request.message_id);
3764    let updated_at = OffsetDateTime::now_utc();
3765    let UpdatedChannelMessage {
3766        message_id,
3767        participant_connection_ids,
3768        notifications,
3769        reply_to_message_id,
3770        timestamp,
3771        deleted_mention_notification_ids,
3772        updated_mention_notifications,
3773    } = session
3774        .db()
3775        .await
3776        .update_channel_message(
3777            channel_id,
3778            message_id,
3779            session.user_id(),
3780            request.body.as_str(),
3781            &request.mentions,
3782            updated_at,
3783        )
3784        .await?;
3785
3786    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3787
3788    let message = proto::ChannelMessage {
3789        sender_id: session.user_id().to_proto(),
3790        id: message_id.to_proto(),
3791        body: request.body.clone(),
3792        mentions: request.mentions.clone(),
3793        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3794        nonce: Some(nonce),
3795        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3796        edited_at: Some(updated_at.unix_timestamp() as u64),
3797    };
3798
3799    response.send(proto::Ack {})?;
3800
3801    let pool = &*session.connection_pool().await;
3802    broadcast(
3803        Some(session.connection_id),
3804        participant_connection_ids,
3805        |connection| {
3806            session.peer.send(
3807                connection,
3808                proto::ChannelMessageUpdate {
3809                    channel_id: channel_id.to_proto(),
3810                    message: Some(message.clone()),
3811                },
3812            )?;
3813
3814            for notification_id in &deleted_mention_notification_ids {
3815                session.peer.send(
3816                    connection,
3817                    proto::DeleteNotification {
3818                        notification_id: (*notification_id).to_proto(),
3819                    },
3820                )?;
3821            }
3822
3823            for notification in &updated_mention_notifications {
3824                session.peer.send(
3825                    connection,
3826                    proto::UpdateNotification {
3827                        notification: Some(notification.clone()),
3828                    },
3829                )?;
3830            }
3831
3832            Ok(())
3833        },
3834    );
3835
3836    send_notifications(pool, &session.peer, notifications);
3837
3838    Ok(())
3839}
3840
3841/// Mark a channel message as read
3842async fn acknowledge_channel_message(
3843    request: proto::AckChannelMessage,
3844    session: Session,
3845) -> Result<()> {
3846    let channel_id = ChannelId::from_proto(request.channel_id);
3847    let message_id = MessageId::from_proto(request.message_id);
3848    let notifications = session
3849        .db()
3850        .await
3851        .observe_channel_message(channel_id, session.user_id(), message_id)
3852        .await?;
3853    send_notifications(
3854        &*session.connection_pool().await,
3855        &session.peer,
3856        notifications,
3857    );
3858    Ok(())
3859}
3860
3861/// Mark a buffer version as synced
3862async fn acknowledge_buffer_version(
3863    request: proto::AckBufferOperation,
3864    session: Session,
3865) -> Result<()> {
3866    let buffer_id = BufferId::from_proto(request.buffer_id);
3867    session
3868        .db()
3869        .await
3870        .observe_buffer_version(
3871            buffer_id,
3872            session.user_id(),
3873            request.epoch as i32,
3874            &request.version,
3875        )
3876        .await?;
3877    Ok(())
3878}
3879
3880/// Get a Supermaven API key for the user
3881async fn get_supermaven_api_key(
3882    _request: proto::GetSupermavenApiKey,
3883    response: Response<proto::GetSupermavenApiKey>,
3884    session: Session,
3885) -> Result<()> {
3886    let user_id: String = session.user_id().to_string();
3887    if !session.is_staff() {
3888        return Err(anyhow!("supermaven not enabled for this account"))?;
3889    }
3890
3891    let email = session.email().context("user must have an email")?;
3892
3893    let supermaven_admin_api = session
3894        .supermaven_client
3895        .as_ref()
3896        .context("supermaven not configured")?;
3897
3898    let result = supermaven_admin_api
3899        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3900        .await?;
3901
3902    response.send(proto::GetSupermavenApiKeyResponse {
3903        api_key: result.api_key,
3904    })?;
3905
3906    Ok(())
3907}
3908
3909/// Start receiving chat updates for a channel
3910async fn join_channel_chat(
3911    request: proto::JoinChannelChat,
3912    response: Response<proto::JoinChannelChat>,
3913    session: Session,
3914) -> Result<()> {
3915    let channel_id = ChannelId::from_proto(request.channel_id);
3916
3917    let db = session.db().await;
3918    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3919        .await?;
3920    let messages = db
3921        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3922        .await?;
3923    response.send(proto::JoinChannelChatResponse {
3924        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925        messages,
3926    })?;
3927    Ok(())
3928}
3929
3930/// Stop receiving chat updates for a channel
3931async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3932    let channel_id = ChannelId::from_proto(request.channel_id);
3933    session
3934        .db()
3935        .await
3936        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3937        .await?;
3938    Ok(())
3939}
3940
3941/// Retrieve the chat history for a channel
3942async fn get_channel_messages(
3943    request: proto::GetChannelMessages,
3944    response: Response<proto::GetChannelMessages>,
3945    session: Session,
3946) -> Result<()> {
3947    let channel_id = ChannelId::from_proto(request.channel_id);
3948    let messages = session
3949        .db()
3950        .await
3951        .get_channel_messages(
3952            channel_id,
3953            session.user_id(),
3954            MESSAGE_COUNT_PER_PAGE,
3955            Some(MessageId::from_proto(request.before_message_id)),
3956        )
3957        .await?;
3958    response.send(proto::GetChannelMessagesResponse {
3959        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3960        messages,
3961    })?;
3962    Ok(())
3963}
3964
3965/// Retrieve specific chat messages
3966async fn get_channel_messages_by_id(
3967    request: proto::GetChannelMessagesById,
3968    response: Response<proto::GetChannelMessagesById>,
3969    session: Session,
3970) -> Result<()> {
3971    let message_ids = request
3972        .message_ids
3973        .iter()
3974        .map(|id| MessageId::from_proto(*id))
3975        .collect::<Vec<_>>();
3976    let messages = session
3977        .db()
3978        .await
3979        .get_channel_messages_by_id(session.user_id(), &message_ids)
3980        .await?;
3981    response.send(proto::GetChannelMessagesResponse {
3982        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3983        messages,
3984    })?;
3985    Ok(())
3986}
3987
3988/// Retrieve the current users notifications
3989async fn get_notifications(
3990    request: proto::GetNotifications,
3991    response: Response<proto::GetNotifications>,
3992    session: Session,
3993) -> Result<()> {
3994    let notifications = session
3995        .db()
3996        .await
3997        .get_notifications(
3998            session.user_id(),
3999            NOTIFICATION_COUNT_PER_PAGE,
4000            request.before_id.map(db::NotificationId::from_proto),
4001        )
4002        .await?;
4003    response.send(proto::GetNotificationsResponse {
4004        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4005        notifications,
4006    })?;
4007    Ok(())
4008}
4009
4010/// Mark notifications as read
4011async fn mark_notification_as_read(
4012    request: proto::MarkNotificationRead,
4013    response: Response<proto::MarkNotificationRead>,
4014    session: Session,
4015) -> Result<()> {
4016    let database = &session.db().await;
4017    let notifications = database
4018        .mark_notification_as_read_by_id(
4019            session.user_id(),
4020            NotificationId::from_proto(request.notification_id),
4021        )
4022        .await?;
4023    send_notifications(
4024        &*session.connection_pool().await,
4025        &session.peer,
4026        notifications,
4027    );
4028    response.send(proto::Ack {})?;
4029    Ok(())
4030}
4031
4032/// Get the current users information
4033async fn get_private_user_info(
4034    _request: proto::GetPrivateUserInfo,
4035    response: Response<proto::GetPrivateUserInfo>,
4036    session: Session,
4037) -> Result<()> {
4038    let db = session.db().await;
4039
4040    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4041    let user = db
4042        .get_user_by_id(session.user_id())
4043        .await?
4044        .context("user not found")?;
4045    let flags = db.get_user_flags(session.user_id()).await?;
4046
4047    response.send(proto::GetPrivateUserInfoResponse {
4048        metrics_id,
4049        staff: user.admin,
4050        flags,
4051        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4052    })?;
4053    Ok(())
4054}
4055
4056/// Accept the terms of service (tos) on behalf of the current user
4057async fn accept_terms_of_service(
4058    _request: proto::AcceptTermsOfService,
4059    response: Response<proto::AcceptTermsOfService>,
4060    session: Session,
4061) -> Result<()> {
4062    let db = session.db().await;
4063
4064    let accepted_tos_at = Utc::now();
4065    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4066        .await?;
4067
4068    response.send(proto::AcceptTermsOfServiceResponse {
4069        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4070    })?;
4071    Ok(())
4072}
4073
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, expressed as a `chrono::Duration` so it can be compared
/// against account-creation timestamps.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4076
/// Issue an LLM API token for the current user.
///
/// The user must exist and must have accepted the terms of service. Before the
/// token is minted, this ensures the user has:
/// - a billing customer: an existing record is used, otherwise one is found or
///   created via Stripe by email address, and
/// - an active billing subscription: an existing one is used, otherwise the
///   customer is subscribed to Zed Free in Stripe and the subscription is
///   recorded.
///
/// The token is built by `LlmTokenClaims::create` from the user, staff status,
/// billing data, feature flags, system id and server config.
///
/// Note: the database handle acquired at the top is held for the whole call,
/// including the Stripe requests.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Tokens are only issued after the terms of service have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or find/create one in Stripe by email.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription, or subscribe the customer to Zed Free and
    // persist the resulting Stripe subscription.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4160
4161fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4162    let message = match message {
4163        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4164        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4165        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4166        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4167        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4168            code: frame.code.into(),
4169            reason: frame.reason.as_str().to_owned().into(),
4170        })),
4171        // We should never receive a frame while reading the message, according
4172        // to the `tungstenite` maintainers:
4173        //
4174        // > It cannot occur when you read messages from the WebSocket, but it
4175        // > can be used when you want to send the raw frames (e.g. you want to
4176        // > send the frames to the WebSocket without composing the full message first).
4177        // >
4178        // > — https://github.com/snapview/tungstenite-rs/issues/268
4179        TungsteniteMessage::Frame(_) => {
4180            bail!("received an unexpected frame while reading the message")
4181        }
4182    };
4183
4184    Ok(message)
4185}
4186
4187fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4188    match message {
4189        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4190        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4191        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4192        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4193        AxumMessage::Close(frame) => {
4194            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4195                code: frame.code.into(),
4196                reason: frame.reason.as_ref().into(),
4197            }))
4198        }
4199    }
4200}
4201
4202fn notify_membership_updated(
4203    connection_pool: &mut ConnectionPool,
4204    result: MembershipUpdated,
4205    user_id: UserId,
4206    peer: &Peer,
4207) {
4208    for membership in &result.new_channels.channel_memberships {
4209        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4210    }
4211    for channel_id in &result.removed_channels {
4212        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4213    }
4214
4215    let user_channels_update = proto::UpdateUserChannels {
4216        channel_memberships: result
4217            .new_channels
4218            .channel_memberships
4219            .iter()
4220            .map(|cm| proto::ChannelMembership {
4221                channel_id: cm.channel_id.to_proto(),
4222                role: cm.role.into(),
4223            })
4224            .collect(),
4225        ..Default::default()
4226    };
4227
4228    let mut update = build_channels_update(result.new_channels);
4229    update.delete_channels = result
4230        .removed_channels
4231        .into_iter()
4232        .map(|id| id.to_proto())
4233        .collect();
4234    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4235
4236    for connection_id in connection_pool.user_connection_ids(user_id) {
4237        peer.send(connection_id, user_channels_update.clone())
4238            .trace_err();
4239        peer.send(connection_id, update.clone()).trace_err();
4240    }
4241}
4242
4243fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4244    proto::UpdateUserChannels {
4245        channel_memberships: channels
4246            .channel_memberships
4247            .iter()
4248            .map(|m| proto::ChannelMembership {
4249                channel_id: m.channel_id.to_proto(),
4250                role: m.role.into(),
4251            })
4252            .collect(),
4253        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4254        observed_channel_message_id: channels.observed_channel_messages.clone(),
4255    }
4256}
4257
4258fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4259    let mut update = proto::UpdateChannels::default();
4260
4261    for channel in channels.channels {
4262        update.channels.push(channel.to_proto());
4263    }
4264
4265    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4266    update.latest_channel_message_ids = channels.latest_channel_messages;
4267
4268    for (channel_id, participants) in channels.channel_participants {
4269        update
4270            .channel_participants
4271            .push(proto::ChannelParticipants {
4272                channel_id: channel_id.to_proto(),
4273                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4274            });
4275    }
4276
4277    for channel in channels.invited_channels {
4278        update.channel_invitations.push(channel.to_proto());
4279    }
4280
4281    update
4282}
4283
4284fn build_initial_contacts_update(
4285    contacts: Vec<db::Contact>,
4286    pool: &ConnectionPool,
4287) -> proto::UpdateContacts {
4288    let mut update = proto::UpdateContacts::default();
4289
4290    for contact in contacts {
4291        match contact {
4292            db::Contact::Accepted { user_id, busy } => {
4293                update.contacts.push(contact_for_user(user_id, busy, pool));
4294            }
4295            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4296            db::Contact::Incoming { user_id } => {
4297                update
4298                    .incoming_requests
4299                    .push(proto::IncomingContactRequest {
4300                        requester_id: user_id.to_proto(),
4301                    })
4302            }
4303        }
4304    }
4305
4306    update
4307}
4308
4309fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4310    proto::Contact {
4311        user_id: user_id.to_proto(),
4312        online: pool.is_user_online(user_id),
4313        busy,
4314    }
4315}
4316
4317fn room_updated(room: &proto::Room, peer: &Peer) {
4318    broadcast(
4319        None,
4320        room.participants
4321            .iter()
4322            .filter_map(|participant| Some(participant.peer_id?.into())),
4323        |peer_id| {
4324            peer.send(
4325                peer_id,
4326                proto::RoomUpdated {
4327                    room: Some(room.clone()),
4328                },
4329            )
4330        },
4331    );
4332}
4333
4334fn channel_updated(
4335    channel: &db::channel::Model,
4336    room: &proto::Room,
4337    peer: &Peer,
4338    pool: &ConnectionPool,
4339) {
4340    let participants = room
4341        .participants
4342        .iter()
4343        .map(|p| p.user_id)
4344        .collect::<Vec<_>>();
4345
4346    broadcast(
4347        None,
4348        pool.channel_connection_ids(channel.root_id())
4349            .filter_map(|(channel_id, role)| {
4350                role.can_see_channel(channel.visibility)
4351                    .then_some(channel_id)
4352            }),
4353        |peer_id| {
4354            peer.send(
4355                peer_id,
4356                proto::UpdateChannels {
4357                    channel_participants: vec![proto::ChannelParticipants {
4358                        channel_id: channel.id.to_proto(),
4359                        participant_user_ids: participants.clone(),
4360                    }],
4361                    ..Default::default()
4362                },
4363            )
4364        },
4365    );
4366}
4367
4368async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4369    let db = session.db().await;
4370
4371    let contacts = db.get_contacts(user_id).await?;
4372    let busy = db.is_user_busy(user_id).await?;
4373
4374    let pool = session.connection_pool().await;
4375    let updated_contact = contact_for_user(user_id, busy, &pool);
4376    for contact in contacts {
4377        if let db::Contact::Accepted {
4378            user_id: contact_user_id,
4379            ..
4380        } = contact
4381        {
4382            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4383                session
4384                    .peer
4385                    .send(
4386                        contact_conn_id,
4387                        proto::UpdateContacts {
4388                            contacts: vec![updated_contact.clone()],
4389                            remove_contacts: Default::default(),
4390                            incoming_requests: Default::default(),
4391                            remove_incoming_requests: Default::default(),
4392                            outgoing_requests: Default::default(),
4393                            remove_outgoing_requests: Default::default(),
4394                        },
4395                    )
4396                    .trace_err();
4397            }
4398        }
4399    }
4400    Ok(())
4401}
4402
4403async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4404    let mut contacts_to_update = HashSet::default();
4405
4406    let room_id;
4407    let canceled_calls_to_user_ids;
4408    let livekit_room;
4409    let delete_livekit_room;
4410    let room;
4411    let channel;
4412
4413    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4414        contacts_to_update.insert(session.user_id());
4415
4416        for project in left_room.left_projects.values() {
4417            project_left(project, session);
4418        }
4419
4420        room_id = RoomId::from_proto(left_room.room.id);
4421        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4422        livekit_room = mem::take(&mut left_room.room.livekit_room);
4423        delete_livekit_room = left_room.deleted;
4424        room = mem::take(&mut left_room.room);
4425        channel = mem::take(&mut left_room.channel);
4426
4427        room_updated(&room, &session.peer);
4428    } else {
4429        return Ok(());
4430    }
4431
4432    if let Some(channel) = channel {
4433        channel_updated(
4434            &channel,
4435            &room,
4436            &session.peer,
4437            &*session.connection_pool().await,
4438        );
4439    }
4440
4441    {
4442        let pool = session.connection_pool().await;
4443        for canceled_user_id in canceled_calls_to_user_ids {
4444            for connection_id in pool.user_connection_ids(canceled_user_id) {
4445                session
4446                    .peer
4447                    .send(
4448                        connection_id,
4449                        proto::CallCanceled {
4450                            room_id: room_id.to_proto(),
4451                        },
4452                    )
4453                    .trace_err();
4454            }
4455            contacts_to_update.insert(canceled_user_id);
4456        }
4457    }
4458
4459    for contact_user_id in contacts_to_update {
4460        update_user_contacts(contact_user_id, session).await?;
4461    }
4462
4463    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4464        live_kit
4465            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4466            .await
4467            .trace_err();
4468
4469        if delete_livekit_room {
4470            live_kit.delete_room(livekit_room).await.trace_err();
4471        }
4472    }
4473
4474    Ok(())
4475}
4476
4477async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4478    let left_channel_buffers = session
4479        .db()
4480        .await
4481        .leave_channel_buffers(session.connection_id)
4482        .await?;
4483
4484    for left_buffer in left_channel_buffers {
4485        channel_buffer_updated(
4486            session.connection_id,
4487            left_buffer.connections,
4488            &proto::UpdateChannelBufferCollaborators {
4489                channel_id: left_buffer.channel_id.to_proto(),
4490                collaborators: left_buffer.collaborators,
4491            },
4492            &session.peer,
4493        );
4494    }
4495
4496    Ok(())
4497}
4498
4499fn project_left(project: &db::LeftProject, session: &Session) {
4500    for connection_id in &project.connection_ids {
4501        if project.should_unshare {
4502            session
4503                .peer
4504                .send(
4505                    *connection_id,
4506                    proto::UnshareProject {
4507                        project_id: project.id.to_proto(),
4508                    },
4509                )
4510                .trace_err();
4511        } else {
4512            session
4513                .peer
4514                .send(
4515                    *connection_id,
4516                    proto::RemoveProjectCollaborator {
4517                        project_id: project.id.to_proto(),
4518                        peer_id: Some(session.connection_id.into()),
4519                    },
4520                )
4521                .trace_err();
4522        }
4523    }
4524}
4525
4526pub trait ResultExt {
4527    type Ok;
4528
4529    fn trace_err(self) -> Option<Self::Ok>;
4530}
4531
4532impl<T, E> ResultExt for Result<T, E>
4533where
4534    E: std::fmt::Debug,
4535{
4536    type Ok = T;
4537
4538    #[track_caller]
4539    fn trace_err(self) -> Option<T> {
4540        match self {
4541            Ok(value) => Some(value),
4542            Err(error) => {
4543                tracing::error!("{:?}", error);
4544                None
4545            }
4546        }
4547    }
4548}