rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::stripe_client::StripeCustomerId;
   9use crate::{
  10    AppState, Error, Result, auth,
  11    db::{
  12        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  13        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  14        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  15        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  16    },
  17    executor::Executor,
  18};
  19use anyhow::{Context as _, anyhow, bail};
  20use async_tungstenite::tungstenite::{
  21    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  22};
  23use axum::{
  24    Extension, Router, TypedHeader,
  25    body::Body,
  26    extract::{
  27        ConnectInfo, WebSocketUpgrade,
  28        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  29    },
  30    headers::{Header, HeaderName},
  31    http::StatusCode,
  32    middleware,
  33    response::IntoResponse,
  34    routing::get,
  35};
  36use chrono::Utc;
  37use collections::{HashMap, HashSet};
  38pub use connection_pool::{ConnectionPool, ZedVersion};
  39use core::fmt::{self, Debug, Formatter};
  40use reqwest_client::ReqwestClient;
  41use rpc::proto::split_repository_update;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  46    stream::FuturesUnordered,
  47};
  48use prometheus::{IntGauge, register_int_gauge};
  49use rpc::{
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        Arc, OnceLock,
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{Semaphore, watch};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    Instrument,
  77    field::{self},
  78    info_span, instrument,
  79};
  80
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111}
 112
 113impl Principal {
 114    fn user(&self) -> &User {
 115        match self {
 116            Principal::User(user) => user,
 117            Principal::Impersonated { user, .. } => user,
 118        }
 119    }
 120
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132        }
 133    }
 134}
 135
 136#[derive(Clone)]
 137struct Session {
 138    principal: Principal,
 139    connection_id: ConnectionId,
 140    db: Arc<tokio::sync::Mutex<DbHandle>>,
 141    peer: Arc<Peer>,
 142    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 143    app_state: Arc<AppState>,
 144    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 145    /// The GeoIP country code for the user.
 146    #[allow(unused)]
 147    geoip_country_code: Option<String>,
 148    system_id: Option<String>,
 149    _executor: Executor,
 150}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn is_staff(&self) -> bool {
 173        match &self.principal {
 174            Principal::User(user) => user.admin,
 175            Principal::Impersonated { .. } => true,
 176        }
 177    }
 178
 179    fn user_id(&self) -> UserId {
 180        match &self.principal {
 181            Principal::User(user) => user.id,
 182            Principal::Impersonated { user, .. } => user.id,
 183        }
 184    }
 185
 186    pub fn email(&self) -> Option<String> {
 187        match &self.principal {
 188            Principal::User(user) => user.email_address.clone(),
 189            Principal::Impersonated { user, .. } => user.email_address.clone(),
 190        }
 191    }
 192}
 193
 194impl Debug for Session {
 195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 196        let mut result = f.debug_struct("Session");
 197        match &self.principal {
 198            Principal::User(user) => {
 199                result.field("user", &user.github_login);
 200            }
 201            Principal::Impersonated { user, admin } => {
 202                result.field("user", &user.github_login);
 203                result.field("impersonator", &admin.github_login);
 204            }
 205        }
 206        result.field("connection_id", &self.connection_id).finish()
 207    }
 208}
 209
 210struct DbHandle(Arc<Database>);
 211
 212impl Deref for DbHandle {
 213    type Target = Database;
 214
 215    fn deref(&self) -> &Self::Target {
 216        self.0.as_ref()
 217    }
 218}
 219
 220pub struct Server {
 221    id: parking_lot::Mutex<ServerId>,
 222    peer: Arc<Peer>,
 223    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 224    app_state: Arc<AppState>,
 225    handlers: HashMap<TypeId, MessageHandler>,
 226    teardown: watch::Sender<bool>,
 227}
 228
 229pub(crate) struct ConnectionPoolGuard<'a> {
 230    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 231    _not_send: PhantomData<Rc<()>>,
 232}
 233
 234#[derive(Serialize)]
 235pub struct ServerSnapshot<'a> {
 236    peer: &'a Peer,
 237    #[serde(serialize_with = "serialize_deref")]
 238    connection_pool: ConnectionPoolGuard<'a>,
 239}
 240
 241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 242where
 243    S: Serializer,
 244    T: Deref<Target = U>,
 245    U: Serialize,
 246{
 247    Serialize::serialize(value.deref(), serializer)
 248}
 249
 250impl Server {
 251    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 252        let mut server = Self {
 253            id: parking_lot::Mutex::new(id),
 254            peer: Peer::new(id.0 as u32),
 255            app_state: app_state.clone(),
 256            connection_pool: Default::default(),
 257            handlers: Default::default(),
 258            teardown: watch::channel(false).0,
 259        };
 260
 261        server
 262            .add_request_handler(ping)
 263            .add_request_handler(create_room)
 264            .add_request_handler(join_room)
 265            .add_request_handler(rejoin_room)
 266            .add_request_handler(leave_room)
 267            .add_request_handler(set_room_participant_role)
 268            .add_request_handler(call)
 269            .add_request_handler(cancel_call)
 270            .add_message_handler(decline_call)
 271            .add_request_handler(update_participant_location)
 272            .add_request_handler(share_project)
 273            .add_message_handler(unshare_project)
 274            .add_request_handler(join_project)
 275            .add_message_handler(leave_project)
 276            .add_request_handler(update_project)
 277            .add_request_handler(update_worktree)
 278            .add_request_handler(update_repository)
 279            .add_request_handler(remove_repository)
 280            .add_message_handler(start_language_server)
 281            .add_message_handler(update_language_server)
 282            .add_message_handler(update_diagnostic_summary)
 283            .add_message_handler(update_worktree_settings)
 284            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 285            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 286            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 287            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 288            .add_request_handler(forward_find_search_candidates_request)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 291            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 293            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 294            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 295            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 296            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 297            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 298            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 300            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 301            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 303            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 304            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 305            .add_request_handler(
 306                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 307            )
 308            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 309            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 310            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 311            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 312            .add_request_handler(
 313                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 314            )
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 349            .add_message_handler(create_buffer_for_peer)
 350            .add_request_handler(update_buffer)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 357            .add_request_handler(get_users)
 358            .add_request_handler(fuzzy_search_users)
 359            .add_request_handler(request_contact)
 360            .add_request_handler(remove_contact)
 361            .add_request_handler(respond_to_contact_request)
 362            .add_message_handler(subscribe_to_channels)
 363            .add_request_handler(create_channel)
 364            .add_request_handler(delete_channel)
 365            .add_request_handler(invite_channel_member)
 366            .add_request_handler(remove_channel_member)
 367            .add_request_handler(set_channel_member_role)
 368            .add_request_handler(set_channel_visibility)
 369            .add_request_handler(rename_channel)
 370            .add_request_handler(join_channel_buffer)
 371            .add_request_handler(leave_channel_buffer)
 372            .add_message_handler(update_channel_buffer)
 373            .add_request_handler(rejoin_channel_buffers)
 374            .add_request_handler(get_channel_members)
 375            .add_request_handler(respond_to_channel_invite)
 376            .add_request_handler(join_channel)
 377            .add_request_handler(join_channel_chat)
 378            .add_message_handler(leave_channel_chat)
 379            .add_request_handler(send_channel_message)
 380            .add_request_handler(remove_channel_message)
 381            .add_request_handler(update_channel_message)
 382            .add_request_handler(get_channel_messages)
 383            .add_request_handler(get_channel_messages_by_id)
 384            .add_request_handler(get_notifications)
 385            .add_request_handler(mark_notification_as_read)
 386            .add_request_handler(move_channel)
 387            .add_request_handler(follow)
 388            .add_message_handler(unfollow)
 389            .add_message_handler(update_followers)
 390            .add_request_handler(get_private_user_info)
 391            .add_request_handler(get_llm_api_token)
 392            .add_request_handler(accept_terms_of_service)
 393            .add_message_handler(acknowledge_channel_message)
 394            .add_message_handler(acknowledge_buffer_version)
 395            .add_request_handler(get_supermaven_api_key)
 396            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 398            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 402            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 405            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 406            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 407            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 408            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 409            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 411            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 412            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 415            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 416            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 417            .add_message_handler(update_context);
 418
 419        Arc::new(server)
 420    }
 421
 422    pub async fn start(&self) -> Result<()> {
 423        let server_id = *self.id.lock();
 424        let app_state = self.app_state.clone();
 425        let peer = self.peer.clone();
 426        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 427        let pool = self.connection_pool.clone();
 428        let livekit_client = self.app_state.livekit_client.clone();
 429
 430        let span = info_span!("start server");
 431        self.app_state.executor.spawn_detached(
 432            async move {
 433                tracing::info!("waiting for cleanup timeout");
 434                timeout.await;
 435                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 436
 437                app_state
 438                    .db
 439                    .delete_stale_channel_chat_participants(
 440                        &app_state.config.zed_environment,
 441                        server_id,
 442                    )
 443                    .await
 444                    .trace_err();
 445
 446                if let Some((room_ids, channel_ids)) = app_state
 447                    .db
 448                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 449                    .await
 450                    .trace_err()
 451                {
 452                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 453                    tracing::info!(
 454                        stale_channel_buffer_count = channel_ids.len(),
 455                        "retrieved stale channel buffers"
 456                    );
 457
 458                    for channel_id in channel_ids {
 459                        if let Some(refreshed_channel_buffer) = app_state
 460                            .db
 461                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 462                            .await
 463                            .trace_err()
 464                        {
 465                            for connection_id in refreshed_channel_buffer.connection_ids {
 466                                peer.send(
 467                                    connection_id,
 468                                    proto::UpdateChannelBufferCollaborators {
 469                                        channel_id: channel_id.to_proto(),
 470                                        collaborators: refreshed_channel_buffer
 471                                            .collaborators
 472                                            .clone(),
 473                                    },
 474                                )
 475                                .trace_err();
 476                            }
 477                        }
 478                    }
 479
 480                    for room_id in room_ids {
 481                        let mut contacts_to_update = HashSet::default();
 482                        let mut canceled_calls_to_user_ids = Vec::new();
 483                        let mut livekit_room = String::new();
 484                        let mut delete_livekit_room = false;
 485
 486                        if let Some(mut refreshed_room) = app_state
 487                            .db
 488                            .clear_stale_room_participants(room_id, server_id)
 489                            .await
 490                            .trace_err()
 491                        {
 492                            tracing::info!(
 493                                room_id = room_id.0,
 494                                new_participant_count = refreshed_room.room.participants.len(),
 495                                "refreshed room"
 496                            );
 497                            room_updated(&refreshed_room.room, &peer);
 498                            if let Some(channel) = refreshed_room.channel.as_ref() {
 499                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 500                            }
 501                            contacts_to_update
 502                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 503                            contacts_to_update
 504                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 505                            canceled_calls_to_user_ids =
 506                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 507                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 508                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 509                        }
 510
 511                        {
 512                            let pool = pool.lock();
 513                            for canceled_user_id in canceled_calls_to_user_ids {
 514                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 515                                    peer.send(
 516                                        connection_id,
 517                                        proto::CallCanceled {
 518                                            room_id: room_id.to_proto(),
 519                                        },
 520                                    )
 521                                    .trace_err();
 522                                }
 523                            }
 524                        }
 525
 526                        for user_id in contacts_to_update {
 527                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 528                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 529                            if let Some((busy, contacts)) = busy.zip(contacts) {
 530                                let pool = pool.lock();
 531                                let updated_contact = contact_for_user(user_id, busy, &pool);
 532                                for contact in contacts {
 533                                    if let db::Contact::Accepted {
 534                                        user_id: contact_user_id,
 535                                        ..
 536                                    } = contact
 537                                    {
 538                                        for contact_conn_id in
 539                                            pool.user_connection_ids(contact_user_id)
 540                                        {
 541                                            peer.send(
 542                                                contact_conn_id,
 543                                                proto::UpdateContacts {
 544                                                    contacts: vec![updated_contact.clone()],
 545                                                    remove_contacts: Default::default(),
 546                                                    incoming_requests: Default::default(),
 547                                                    remove_incoming_requests: Default::default(),
 548                                                    outgoing_requests: Default::default(),
 549                                                    remove_outgoing_requests: Default::default(),
 550                                                },
 551                                            )
 552                                            .trace_err();
 553                                        }
 554                                    }
 555                                }
 556                            }
 557                        }
 558
 559                        if let Some(live_kit) = livekit_client.as_ref() {
 560                            if delete_livekit_room {
 561                                live_kit.delete_room(livekit_room).await.trace_err();
 562                            }
 563                        }
 564                    }
 565                }
 566
 567                app_state
 568                    .db
 569                    .delete_stale_channel_chat_participants(
 570                        &app_state.config.zed_environment,
 571                        server_id,
 572                    )
 573                    .await
 574                    .trace_err();
 575
 576                app_state
 577                    .db
 578                    .clear_old_worktree_entries(server_id)
 579                    .await
 580                    .trace_err();
 581
 582                app_state
 583                    .db
 584                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 585                    .await
 586                    .trace_err();
 587            }
 588            .instrument(span),
 589        );
 590        Ok(())
 591    }
 592
 593    pub fn teardown(&self) {
 594        self.peer.teardown();
 595        self.connection_pool.lock().reset();
 596        let _ = self.teardown.send(true);
 597    }
 598
 599    #[cfg(test)]
 600    pub fn reset(&self, id: ServerId) {
 601        self.teardown();
 602        *self.id.lock() = id;
 603        self.peer.reset(id.0 as u32);
 604        let _ = self.teardown.send(false);
 605    }
 606
 607    #[cfg(test)]
 608    pub fn id(&self) -> ServerId {
 609        *self.id.lock()
 610    }
 611
 612    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 613    where
 614        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 615        Fut: 'static + Send + Future<Output = Result<()>>,
 616        M: EnvelopedMessage,
 617    {
 618        let prev_handler = self.handlers.insert(
 619            TypeId::of::<M>(),
 620            Box::new(move |envelope, session| {
 621                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 622                let received_at = envelope.received_at;
 623                tracing::info!("message received");
 624                let start_time = Instant::now();
 625                let future = (handler)(*envelope, session);
 626                async move {
 627                    let result = future.await;
 628                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 629                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 630                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 631                    let payload_type = M::NAME;
 632
 633                    match result {
 634                        Err(error) => {
 635                            tracing::error!(
 636                                ?error,
 637                                total_duration_ms,
 638                                processing_duration_ms,
 639                                queue_duration_ms,
 640                                payload_type,
 641                                "error handling message"
 642                            )
 643                        }
 644                        Ok(()) => tracing::info!(
 645                            total_duration_ms,
 646                            processing_duration_ms,
 647                            queue_duration_ms,
 648                            "finished handling message"
 649                        ),
 650                    }
 651                }
 652                .boxed()
 653            }),
 654        );
 655        if prev_handler.is_some() {
 656            panic!("registered a handler for the same message twice");
 657        }
 658        self
 659    }
 660
 661    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 662    where
 663        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 664        Fut: 'static + Send + Future<Output = Result<()>>,
 665        M: EnvelopedMessage,
 666    {
 667        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 668        self
 669    }
 670
 671    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 672    where
 673        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 674        Fut: Send + Future<Output = Result<()>>,
 675        M: RequestMessage,
 676    {
 677        let handler = Arc::new(handler);
 678        self.add_handler(move |envelope, session| {
 679            let receipt = envelope.receipt();
 680            let handler = handler.clone();
 681            async move {
 682                let peer = session.peer.clone();
 683                let responded = Arc::new(AtomicBool::default());
 684                let response = Response {
 685                    peer: peer.clone(),
 686                    responded: responded.clone(),
 687                    receipt,
 688                };
 689                match (handler)(envelope.payload, response, session).await {
 690                    Ok(()) => {
 691                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 692                            Ok(())
 693                        } else {
 694                            Err(anyhow!("handler did not send a response"))?
 695                        }
 696                    }
 697                    Err(error) => {
 698                        let proto_err = match &error {
 699                            Error::Internal(err) => err.to_proto(),
 700                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 701                        };
 702                        peer.respond_with_error(receipt, proto_err)?;
 703                        Err(error)
 704                    }
 705                }
 706            }
 707        })
 708    }
 709
 710    pub fn handle_connection(
 711        self: &Arc<Self>,
 712        connection: Connection,
 713        address: String,
 714        principal: Principal,
 715        zed_version: ZedVersion,
 716        geoip_country_code: Option<String>,
 717        system_id: Option<String>,
 718        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 719        executor: Executor,
 720    ) -> impl Future<Output = ()> + use<> {
 721        let this = self.clone();
 722        let span = info_span!("handle connection", %address,
 723            connection_id=field::Empty,
 724            user_id=field::Empty,
 725            login=field::Empty,
 726            impersonator=field::Empty,
 727            geoip_country_code=field::Empty
 728        );
 729        principal.update_span(&span);
 730        if let Some(country_code) = geoip_country_code.as_ref() {
 731            span.record("geoip_country_code", country_code);
 732        }
 733
 734        let mut teardown = self.teardown.subscribe();
 735        async move {
 736            if *teardown.borrow() {
 737                tracing::error!("server is tearing down");
 738                return
 739            }
 740            let (connection_id, handle_io, mut incoming_rx) = this
 741                .peer
 742                .add_connection(connection, {
 743                    let executor = executor.clone();
 744                    move |duration| executor.sleep(duration)
 745                });
 746            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 747
 748            tracing::info!("connection opened");
 749
 750            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 751            let http_client = match ReqwestClient::user_agent(&user_agent) {
 752                Ok(http_client) => Arc::new(http_client),
 753                Err(error) => {
 754                    tracing::error!(?error, "failed to create HTTP client");
 755                    return;
 756                }
 757            };
 758
 759            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 760                    supermaven_admin_api_key.to_string(),
 761                    http_client.clone(),
 762                )));
 763
 764            let session = Session {
 765                principal: principal.clone(),
 766                connection_id,
 767                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 768                peer: this.peer.clone(),
 769                connection_pool: this.connection_pool.clone(),
 770                app_state: this.app_state.clone(),
 771                geoip_country_code,
 772                system_id,
 773                _executor: executor.clone(),
 774                supermaven_client,
 775            };
 776
 777            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
 778                tracing::error!(?error, "failed to send initial client update");
 779                return;
 780            }
 781
 782            let handle_io = handle_io.fuse();
 783            futures::pin_mut!(handle_io);
 784
 785            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 786            // This prevents deadlocks when e.g., client A performs a request to client B and
 787            // client B performs a request to client A. If both clients stop processing further
 788            // messages until their respective request completes, they won't have a chance to
 789            // respond to the other client's request and cause a deadlock.
 790            //
 791            // This arrangement ensures we will attempt to process earlier messages first, but fall
 792            // back to processing messages arrived later in the spirit of making progress.
 793            let mut foreground_message_handlers = FuturesUnordered::new();
 794            let concurrent_handlers = Arc::new(Semaphore::new(256));
 795            loop {
 796                let next_message = async {
 797                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 798                    let message = incoming_rx.next().await;
 799                    (permit, message)
 800                }.fuse();
 801                futures::pin_mut!(next_message);
 802                futures::select_biased! {
 803                    _ = teardown.changed().fuse() => return,
 804                    result = handle_io => {
 805                        if let Err(error) = result {
 806                            tracing::error!(?error, "error handling I/O");
 807                        }
 808                        break;
 809                    }
 810                    _ = foreground_message_handlers.next() => {}
 811                    next_message = next_message => {
 812                        let (permit, message) = next_message;
 813                        if let Some(message) = message {
 814                            let type_name = message.payload_type_name();
 815                            // note: we copy all the fields from the parent span so we can query them in the logs.
 816                            // (https://github.com/tokio-rs/tracing/issues/2670).
 817                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 818                                user_id=field::Empty,
 819                                login=field::Empty,
 820                                impersonator=field::Empty,
 821                            );
 822                            principal.update_span(&span);
 823                            let span_enter = span.enter();
 824                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 825                                let is_background = message.is_background();
 826                                let handle_message = (handler)(message, session.clone());
 827                                drop(span_enter);
 828
 829                                let handle_message = async move {
 830                                    handle_message.await;
 831                                    drop(permit);
 832                                }.instrument(span);
 833                                if is_background {
 834                                    executor.spawn_detached(handle_message);
 835                                } else {
 836                                    foreground_message_handlers.push(handle_message);
 837                                }
 838                            } else {
 839                                tracing::error!("no message handler");
 840                            }
 841                        } else {
 842                            tracing::info!("connection closed");
 843                            break;
 844                        }
 845                    }
 846                }
 847            }
 848
 849            drop(foreground_message_handlers);
 850            tracing::info!("signing out");
 851            if let Err(error) = connection_lost(session, teardown, executor).await {
 852                tracing::error!(?error, "error signing out");
 853            }
 854
 855        }.instrument(span)
 856    }
 857
 858    async fn send_initial_client_update(
 859        &self,
 860        connection_id: ConnectionId,
 861        zed_version: ZedVersion,
 862        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 863        session: &Session,
 864    ) -> Result<()> {
 865        self.peer.send(
 866            connection_id,
 867            proto::Hello {
 868                peer_id: Some(connection_id.into()),
 869            },
 870        )?;
 871        tracing::info!("sent hello message");
 872        if let Some(send_connection_id) = send_connection_id.take() {
 873            let _ = send_connection_id.send(connection_id);
 874        }
 875
 876        match &session.principal {
 877            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 878                if !user.connected_once {
 879                    self.peer.send(connection_id, proto::ShowContacts {})?;
 880                    self.app_state
 881                        .db
 882                        .set_user_connected_once(user.id, true)
 883                        .await?;
 884                }
 885
 886                update_user_plan(session).await?;
 887
 888                let contacts = self.app_state.db.get_contacts(user.id).await?;
 889
 890                {
 891                    let mut pool = self.connection_pool.lock();
 892                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 893                    self.peer.send(
 894                        connection_id,
 895                        build_initial_contacts_update(contacts, &pool),
 896                    )?;
 897                }
 898
 899                if should_auto_subscribe_to_channels(zed_version) {
 900                    subscribe_user_to_channels(user.id, session).await?;
 901                }
 902
 903                if let Some(incoming_call) =
 904                    self.app_state.db.incoming_call_for_user(user.id).await?
 905                {
 906                    self.peer.send(connection_id, incoming_call)?;
 907                }
 908
 909                update_user_contacts(user.id, session).await?;
 910            }
 911        }
 912
 913        Ok(())
 914    }
 915
 916    pub async fn invite_code_redeemed(
 917        self: &Arc<Self>,
 918        inviter_id: UserId,
 919        invitee_id: UserId,
 920    ) -> Result<()> {
 921        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 922            if let Some(code) = &user.invite_code {
 923                let pool = self.connection_pool.lock();
 924                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 925                for connection_id in pool.user_connection_ids(inviter_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateContacts {
 929                            contacts: vec![invitee_contact.clone()],
 930                            ..Default::default()
 931                        },
 932                    )?;
 933                    self.peer.send(
 934                        connection_id,
 935                        proto::UpdateInviteInfo {
 936                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 937                            count: user.invite_count as u32,
 938                        },
 939                    )?;
 940                }
 941            }
 942        }
 943        Ok(())
 944    }
 945
 946    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 947        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 948            if let Some(invite_code) = &user.invite_code {
 949                let pool = self.connection_pool.lock();
 950                for connection_id in pool.user_connection_ids(user_id) {
 951                    self.peer.send(
 952                        connection_id,
 953                        proto::UpdateInviteInfo {
 954                            url: format!(
 955                                "{}{}",
 956                                self.app_state.config.invite_link_prefix, invite_code
 957                            ),
 958                            count: user.invite_count as u32,
 959                        },
 960                    )?;
 961                }
 962            }
 963        }
 964        Ok(())
 965    }
 966
 967    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 968        let user = self
 969            .app_state
 970            .db
 971            .get_user_by_id(user_id)
 972            .await?
 973            .context("user not found")?;
 974
 975        let update_user_plan = make_update_user_plan_message(
 976            &user,
 977            user.admin,
 978            &self.app_state.db,
 979            self.app_state.llm_db.clone(),
 980        )
 981        .await?;
 982
 983        let pool = self.connection_pool.lock();
 984        for connection_id in pool.user_connection_ids(user_id) {
 985            self.peer
 986                .send(connection_id, update_user_plan.clone())
 987                .trace_err();
 988        }
 989
 990        Ok(())
 991    }
 992
 993    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 994        let pool = self.connection_pool.lock();
 995        for connection_id in pool.user_connection_ids(user_id) {
 996            self.peer
 997                .send(connection_id, proto::RefreshLlmToken {})
 998                .trace_err();
 999        }
1000    }
1001
1002    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1003        ServerSnapshot {
1004            connection_pool: ConnectionPoolGuard {
1005                guard: self.connection_pool.lock(),
1006                _not_send: PhantomData,
1007            },
1008            peer: &self.peer,
1009        }
1010    }
1011}
1012
1013impl Deref for ConnectionPoolGuard<'_> {
1014    type Target = ConnectionPool;
1015
1016    fn deref(&self) -> &Self::Target {
1017        &self.guard
1018    }
1019}
1020
1021impl DerefMut for ConnectionPoolGuard<'_> {
1022    fn deref_mut(&mut self) -> &mut Self::Target {
1023        &mut self.guard
1024    }
1025}
1026
1027impl Drop for ConnectionPoolGuard<'_> {
1028    fn drop(&mut self) {
1029        #[cfg(test)]
1030        self.check_invariants();
1031    }
1032}
1033
1034fn broadcast<F>(
1035    sender_id: Option<ConnectionId>,
1036    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1037    mut f: F,
1038) where
1039    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1040{
1041    for receiver_id in receiver_ids {
1042        if Some(receiver_id) != sender_id {
1043            if let Err(error) = f(receiver_id) {
1044                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1045            }
1046        }
1047    }
1048}
1049
1050pub struct ProtocolVersion(u32);
1051
1052impl Header for ProtocolVersion {
1053    fn name() -> &'static HeaderName {
1054        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1055        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1056    }
1057
1058    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1059    where
1060        Self: Sized,
1061        I: Iterator<Item = &'i axum::http::HeaderValue>,
1062    {
1063        let version = values
1064            .next()
1065            .ok_or_else(axum::headers::Error::invalid)?
1066            .to_str()
1067            .map_err(|_| axum::headers::Error::invalid())?
1068            .parse()
1069            .map_err(|_| axum::headers::Error::invalid())?;
1070        Ok(Self(version))
1071    }
1072
1073    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1074        values.extend([self.0.to_string().parse().unwrap()]);
1075    }
1076}
1077
1078pub struct AppVersionHeader(SemanticVersion);
1079impl Header for AppVersionHeader {
1080    fn name() -> &'static HeaderName {
1081        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1082        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1083    }
1084
1085    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1086    where
1087        Self: Sized,
1088        I: Iterator<Item = &'i axum::http::HeaderValue>,
1089    {
1090        let version = values
1091            .next()
1092            .ok_or_else(axum::headers::Error::invalid)?
1093            .to_str()
1094            .map_err(|_| axum::headers::Error::invalid())?
1095            .parse()
1096            .map_err(|_| axum::headers::Error::invalid())?;
1097        Ok(Self(version))
1098    }
1099
1100    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1101        values.extend([self.0.to_string().parse().unwrap()]);
1102    }
1103}
1104
1105pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1106    Router::new()
1107        .route("/rpc", get(handle_websocket_request))
1108        .layer(
1109            ServiceBuilder::new()
1110                .layer(Extension(server.app_state.clone()))
1111                .layer(middleware::from_fn(auth::validate_header)),
1112        )
1113        .route("/metrics", get(handle_metrics))
1114        .layer(Extension(server))
1115}
1116
1117pub async fn handle_websocket_request(
1118    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1119    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1120    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1121    Extension(server): Extension<Arc<Server>>,
1122    Extension(principal): Extension<Principal>,
1123    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1124    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1125    ws: WebSocketUpgrade,
1126) -> axum::response::Response {
1127    if protocol_version != rpc::PROTOCOL_VERSION {
1128        return (
1129            StatusCode::UPGRADE_REQUIRED,
1130            "client must be upgraded".to_string(),
1131        )
1132            .into_response();
1133    }
1134
1135    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1136        return (
1137            StatusCode::UPGRADE_REQUIRED,
1138            "no version header found".to_string(),
1139        )
1140            .into_response();
1141    };
1142
1143    if !version.can_collaborate() {
1144        return (
1145            StatusCode::UPGRADE_REQUIRED,
1146            "client must be upgraded".to_string(),
1147        )
1148            .into_response();
1149    }
1150
1151    let socket_address = socket_address.to_string();
1152    ws.on_upgrade(move |socket| {
1153        let socket = socket
1154            .map_ok(to_tungstenite_message)
1155            .err_into()
1156            .with(|message| async move { to_axum_message(message) });
1157        let connection = Connection::new(Box::pin(socket));
1158        async move {
1159            server
1160                .handle_connection(
1161                    connection,
1162                    socket_address,
1163                    principal,
1164                    version,
1165                    country_code_header.map(|header| header.to_string()),
1166                    system_id_header.map(|header| header.to_string()),
1167                    None,
1168                    Executor::Production,
1169                )
1170                .await;
1171        }
1172    })
1173}
1174
1175pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1176    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1177    let connections_metric = CONNECTIONS_METRIC
1178        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1179
1180    let connections = server
1181        .connection_pool
1182        .lock()
1183        .connections()
1184        .filter(|connection| !connection.admin)
1185        .count();
1186    connections_metric.set(connections as _);
1187
1188    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1189    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1190        register_int_gauge!(
1191            "shared_projects",
1192            "number of open projects with one or more guests"
1193        )
1194        .unwrap()
1195    });
1196
1197    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1198    shared_projects_metric.set(shared_projects as _);
1199
1200    let encoder = prometheus::TextEncoder::new();
1201    let metric_families = prometheus::gather();
1202    let encoded_metrics = encoder
1203        .encode_to_string(&metric_families)
1204        .map_err(|err| anyhow!("{err}"))?;
1205    Ok(encoded_metrics)
1206}
1207
/// Cleans up server-side state for a connection whose message loop has ended.
///
/// Runs in two phases:
/// 1. Immediately: disconnects the peer and removes the connection from the connection pool.
///    A removal failure is returned as an error. The lost connection is then reported to the
///    database; an error there only goes through `trace_err` and is not returned.
/// 2. After `RECONNECT_TIMEOUT` elapses, the session leaves its room and channel buffers.
///    If the user then has no other connection in the pool, any pending call is declined.
///    Finally the user's contacts are refreshed.
///    NOTE(review): the timeout is presumably a grace period for the client to reconnect;
///    confirm.
///    If `teardown` reports a change before the timeout, phase 2 is skipped entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order, so the timeout branch is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls if none of the user's other connections are online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1254
1255/// Acknowledges a ping from a client, used to keep the connection alive.
1256async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1257    response.send(proto::Ack {})?;
1258    Ok(())
1259}
1260
1261/// Creates a new room for calling (outside of channels)
1262async fn create_room(
1263    _request: proto::CreateRoom,
1264    response: Response<proto::CreateRoom>,
1265    session: Session,
1266) -> Result<()> {
1267    let livekit_room = nanoid::nanoid!(30);
1268
1269    let live_kit_connection_info = util::maybe!(async {
1270        let live_kit = session.app_state.livekit_client.as_ref();
1271        let live_kit = live_kit?;
1272        let user_id = session.user_id().to_string();
1273
1274        let token = live_kit
1275            .room_token(&livekit_room, &user_id.to_string())
1276            .trace_err()?;
1277
1278        Some(proto::LiveKitConnectionInfo {
1279            server_url: live_kit.url().into(),
1280            token,
1281            can_publish: true,
1282        })
1283    })
1284    .await;
1285
1286    let room = session
1287        .db()
1288        .await
1289        .create_room(session.user_id(), session.connection_id, &livekit_room)
1290        .await?;
1291
1292    response.send(proto::CreateRoomResponse {
1293        room: Some(room.clone()),
1294        live_kit_connection_info,
1295    })?;
1296
1297    update_user_contacts(session.user_id(), &session).await?;
1298    Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303    request: proto::JoinRoom,
1304    response: Response<proto::JoinRoom>,
1305    session: Session,
1306) -> Result<()> {
1307    let room_id = RoomId::from_proto(request.id);
1308
1309    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311    if let Some(channel_id) = channel_id {
1312        return join_channel_internal(channel_id, Box::new(response), session).await;
1313    }
1314
1315    let joined_room = {
1316        let room = session
1317            .db()
1318            .await
1319            .join_room(room_id, session.user_id(), session.connection_id)
1320            .await?;
1321        room_updated(&room.room, &session.peer);
1322        room.into_inner()
1323    };
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id())
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340
1341    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342    {
1343        live_kit
1344            .room_token(
1345                &joined_room.room.livekit_room,
1346                &session.user_id().to_string(),
1347            )
1348            .trace_err()
1349            .map(|token| proto::LiveKitConnectionInfo {
1350                server_url: live_kit.url().into(),
1351                token,
1352                can_publish: true,
1353            })
1354    } else {
1355        None
1356    };
1357
1358    response.send(proto::JoinRoomResponse {
1359        room: Some(joined_room.room),
1360        channel_id: None,
1361        live_kit_connection_info,
1362    })?;
1363
1364    update_user_contacts(session.user_id(), &session).await?;
1365    Ok(())
1366}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Asks the database to rejoin the room with the new connection.
/// 2. Replies with the room, the re-shared projects (id + collaborators), and the rejoined
///    projects, then broadcasts the room via `room_updated`.
/// 3. For each re-shared project, tells its collaborators that the peer id changed from the
///    project's `old_connection_id` to this connection. It also forwards the project's
///    worktrees to the collaborators, excluding this connection.
/// 4. Notifies collaborators of rejoined projects and streams their state back to this
///    connection (`notify_rejoined_projects`).
/// 5. After the database result is unwrapped and dropped, sends a channel update if the room
///    belongs to a channel, then refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in inside the block below and used after it.
    let room;
    let channel;
    // NOTE(review): `into_inner()` below suggests the database result is a guard; this block
    // presumably bounds its lifetime so it is released before `channel_updated` locks the
    // connection pool. Confirm before reordering.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // NOTE(review): re-shared projects are presumably ones this user hosts; confirm.
        for project in &rejoined_room.reshared_projects {
            // Swap the old peer id for the new one on every collaborator's side.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktrees to everyone except the rejoining connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes the worktrees out of each rejoined project while streaming them to the client.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
1460fn notify_rejoined_projects(
1461    rejoined_projects: &mut Vec<RejoinedProject>,
1462    session: &Session,
1463) -> Result<()> {
1464    for project in rejoined_projects.iter() {
1465        for collaborator in &project.collaborators {
1466            session
1467                .peer
1468                .send(
1469                    collaborator.connection_id,
1470                    proto::UpdateProjectCollaborator {
1471                        project_id: project.id.to_proto(),
1472                        old_peer_id: Some(project.old_connection_id.into()),
1473                        new_peer_id: Some(session.connection_id.into()),
1474                    },
1475                )
1476                .trace_err();
1477        }
1478    }
1479
1480    for project in rejoined_projects {
1481        for worktree in mem::take(&mut project.worktrees) {
1482            // Stream this worktree's entries.
1483            let message = proto::UpdateWorktree {
1484                project_id: project.id.to_proto(),
1485                worktree_id: worktree.id,
1486                abs_path: worktree.abs_path.clone(),
1487                root_name: worktree.root_name,
1488                updated_entries: worktree.updated_entries,
1489                removed_entries: worktree.removed_entries,
1490                scan_id: worktree.scan_id,
1491                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492                updated_repositories: worktree.updated_repositories,
1493                removed_repositories: worktree.removed_repositories,
1494            };
1495            for update in proto::split_worktree_update(message) {
1496                session.peer.send(session.connection_id, update)?;
1497            }
1498
1499            // Stream this worktree's diagnostics.
1500            for summary in worktree.diagnostic_summaries {
1501                session.peer.send(
1502                    session.connection_id,
1503                    proto::UpdateDiagnosticSummary {
1504                        project_id: project.id.to_proto(),
1505                        worktree_id: worktree.id,
1506                        summary: Some(summary),
1507                    },
1508                )?;
1509            }
1510
1511            for settings_file in worktree.settings_files {
1512                session.peer.send(
1513                    session.connection_id,
1514                    proto::UpdateWorktreeSettings {
1515                        project_id: project.id.to_proto(),
1516                        worktree_id: worktree.id,
1517                        path: settings_file.path,
1518                        content: Some(settings_file.content),
1519                        kind: Some(settings_file.kind.to_proto().into()),
1520                    },
1521                )?;
1522            }
1523        }
1524
1525        for repository in mem::take(&mut project.updated_repositories) {
1526            for update in split_repository_update(repository) {
1527                session.peer.send(session.connection_id, update)?;
1528            }
1529        }
1530
1531        for id in mem::take(&mut project.removed_repositories) {
1532            session.peer.send(
1533                session.connection_id,
1534                proto::RemoveRepository {
1535                    project_id: project.id.to_proto(),
1536                    id,
1537                },
1538            )?;
1539        }
1540    }
1541
1542    Ok(())
1543}
1544
1545/// leave room disconnects from the room.
1546async fn leave_room(
1547    _: proto::LeaveRoom,
1548    response: Response<proto::LeaveRoom>,
1549    session: Session,
1550) -> Result<()> {
1551    leave_room_for_session(&session, session.connection_id).await?;
1552    response.send(proto::Ack {})?;
1553    Ok(())
1554}
1555
1556/// Updates the permissions of someone else in the room.
1557async fn set_room_participant_role(
1558    request: proto::SetRoomParticipantRole,
1559    response: Response<proto::SetRoomParticipantRole>,
1560    session: Session,
1561) -> Result<()> {
1562    let user_id = UserId::from_proto(request.user_id);
1563    let role = ChannelRole::from(request.role());
1564
1565    let (livekit_room, can_publish) = {
1566        let room = session
1567            .db()
1568            .await
1569            .set_room_participant_role(
1570                session.user_id(),
1571                RoomId::from_proto(request.room_id),
1572                user_id,
1573                role,
1574            )
1575            .await?;
1576
1577        let livekit_room = room.livekit_room.clone();
1578        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1579        room_updated(&room, &session.peer);
1580        (livekit_room, can_publish)
1581    };
1582
1583    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1584        live_kit
1585            .update_participant(
1586                livekit_room.clone(),
1587                request.user_id.to_string(),
1588                livekit_api::proto::ParticipantPermission {
1589                    can_subscribe: true,
1590                    can_publish,
1591                    can_publish_data: can_publish,
1592                    hidden: false,
1593                    recorder: false,
1594                },
1595            )
1596            .await
1597            .trace_err();
1598    }
1599
1600    response.send(proto::Ack {})?;
1601    Ok(())
1602}
1603
1604/// Call someone else into the current room
1605async fn call(
1606    request: proto::Call,
1607    response: Response<proto::Call>,
1608    session: Session,
1609) -> Result<()> {
1610    let room_id = RoomId::from_proto(request.room_id);
1611    let calling_user_id = session.user_id();
1612    let calling_connection_id = session.connection_id;
1613    let called_user_id = UserId::from_proto(request.called_user_id);
1614    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1615    if !session
1616        .db()
1617        .await
1618        .has_contact(calling_user_id, called_user_id)
1619        .await?
1620    {
1621        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1622    }
1623
1624    let incoming_call = {
1625        let (room, incoming_call) = &mut *session
1626            .db()
1627            .await
1628            .call(
1629                room_id,
1630                calling_user_id,
1631                calling_connection_id,
1632                called_user_id,
1633                initial_project_id,
1634            )
1635            .await?;
1636        room_updated(room, &session.peer);
1637        mem::take(incoming_call)
1638    };
1639    update_user_contacts(called_user_id, &session).await?;
1640
1641    let mut calls = session
1642        .connection_pool()
1643        .await
1644        .user_connection_ids(called_user_id)
1645        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1646        .collect::<FuturesUnordered<_>>();
1647
1648    while let Some(call_response) = calls.next().await {
1649        match call_response.as_ref() {
1650            Ok(_) => {
1651                response.send(proto::Ack {})?;
1652                return Ok(());
1653            }
1654            Err(_) => {
1655                call_response.trace_err();
1656            }
1657        }
1658    }
1659
1660    {
1661        let room = session
1662            .db()
1663            .await
1664            .call_failed(room_id, called_user_id)
1665            .await?;
1666        room_updated(&room, &session.peer);
1667    }
1668    update_user_contacts(called_user_id, &session).await?;
1669
1670    Err(anyhow!("failed to ring user"))?
1671}
1672
1673/// Cancel an outgoing call.
1674async fn cancel_call(
1675    request: proto::CancelCall,
1676    response: Response<proto::CancelCall>,
1677    session: Session,
1678) -> Result<()> {
1679    let called_user_id = UserId::from_proto(request.called_user_id);
1680    let room_id = RoomId::from_proto(request.room_id);
1681    {
1682        let room = session
1683            .db()
1684            .await
1685            .cancel_call(room_id, session.connection_id, called_user_id)
1686            .await?;
1687        room_updated(&room, &session.peer);
1688    }
1689
1690    for connection_id in session
1691        .connection_pool()
1692        .await
1693        .user_connection_ids(called_user_id)
1694    {
1695        session
1696            .peer
1697            .send(
1698                connection_id,
1699                proto::CallCanceled {
1700                    room_id: room_id.to_proto(),
1701                },
1702            )
1703            .trace_err();
1704    }
1705    response.send(proto::Ack {})?;
1706
1707    update_user_contacts(called_user_id, &session).await?;
1708    Ok(())
1709}
1710
1711/// Decline an incoming call.
1712async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1713    let room_id = RoomId::from_proto(message.room_id);
1714    {
1715        let room = session
1716            .db()
1717            .await
1718            .decline_call(Some(room_id), session.user_id())
1719            .await?
1720            .context("declining call")?;
1721        room_updated(&room, &session.peer);
1722    }
1723
1724    for connection_id in session
1725        .connection_pool()
1726        .await
1727        .user_connection_ids(session.user_id())
1728    {
1729        session
1730            .peer
1731            .send(
1732                connection_id,
1733                proto::CallCanceled {
1734                    room_id: room_id.to_proto(),
1735                },
1736            )
1737            .trace_err();
1738    }
1739    update_user_contacts(session.user_id(), &session).await?;
1740    Ok(())
1741}
1742
1743/// Updates other participants in the room with your current location.
1744async fn update_participant_location(
1745    request: proto::UpdateParticipantLocation,
1746    response: Response<proto::UpdateParticipantLocation>,
1747    session: Session,
1748) -> Result<()> {
1749    let room_id = RoomId::from_proto(request.room_id);
1750    let location = request.location.context("invalid location")?;
1751
1752    let db = session.db().await;
1753    let room = db
1754        .update_room_participant_location(room_id, session.connection_id, location)
1755        .await?;
1756
1757    room_updated(&room, &session.peer);
1758    response.send(proto::Ack {})?;
1759    Ok(())
1760}
1761
1762/// Share a project into the room.
1763async fn share_project(
1764    request: proto::ShareProject,
1765    response: Response<proto::ShareProject>,
1766    session: Session,
1767) -> Result<()> {
1768    let (project_id, room) = &*session
1769        .db()
1770        .await
1771        .share_project(
1772            RoomId::from_proto(request.room_id),
1773            session.connection_id,
1774            &request.worktrees,
1775            request.is_ssh_project,
1776        )
1777        .await?;
1778    response.send(proto::ShareProjectResponse {
1779        project_id: project_id.to_proto(),
1780    })?;
1781    room_updated(room, &session.peer);
1782
1783    Ok(())
1784}
1785
1786/// Unshare a project from the room.
1787async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1788    let project_id = ProjectId::from_proto(message.project_id);
1789    unshare_project_internal(project_id, session.connection_id, &session).await
1790}
1791
1792async fn unshare_project_internal(
1793    project_id: ProjectId,
1794    connection_id: ConnectionId,
1795    session: &Session,
1796) -> Result<()> {
1797    let delete = {
1798        let room_guard = session
1799            .db()
1800            .await
1801            .unshare_project(project_id, connection_id)
1802            .await?;
1803
1804        let (delete, room, guest_connection_ids) = &*room_guard;
1805
1806        let message = proto::UnshareProject {
1807            project_id: project_id.to_proto(),
1808        };
1809
1810        broadcast(
1811            Some(connection_id),
1812            guest_connection_ids.iter().copied(),
1813            |conn_id| session.peer.send(conn_id, message.clone()),
1814        );
1815        if let Some(room) = room {
1816            room_updated(room, &session.peer);
1817        }
1818
1819        *delete
1820    };
1821
1822    if delete {
1823        let db = session.db().await;
1824        db.delete_project(project_id).await?;
1825    }
1826
1827    Ok(())
1828}
1829
1830/// Join someone elses shared project.
1831async fn join_project(
1832    request: proto::JoinProject,
1833    response: Response<proto::JoinProject>,
1834    session: Session,
1835) -> Result<()> {
1836    let project_id = ProjectId::from_proto(request.project_id);
1837
1838    tracing::info!(%project_id, "join project");
1839
1840    let db = session.db().await;
1841    let (project, replica_id) = &mut *db
1842        .join_project(
1843            project_id,
1844            session.connection_id,
1845            session.user_id(),
1846            request.committer_name.clone(),
1847            request.committer_email.clone(),
1848        )
1849        .await?;
1850    drop(db);
1851    tracing::info!(%project_id, "join remote project");
1852    let collaborators = project
1853        .collaborators
1854        .iter()
1855        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1856        .map(|collaborator| collaborator.to_proto())
1857        .collect::<Vec<_>>();
1858    let project_id = project.id;
1859    let guest_user_id = session.user_id();
1860
1861    let worktrees = project
1862        .worktrees
1863        .iter()
1864        .map(|(id, worktree)| proto::WorktreeMetadata {
1865            id: *id,
1866            root_name: worktree.root_name.clone(),
1867            visible: worktree.visible,
1868            abs_path: worktree.abs_path.clone(),
1869        })
1870        .collect::<Vec<_>>();
1871
1872    let add_project_collaborator = proto::AddProjectCollaborator {
1873        project_id: project_id.to_proto(),
1874        collaborator: Some(proto::Collaborator {
1875            peer_id: Some(session.connection_id.into()),
1876            replica_id: replica_id.0 as u32,
1877            user_id: guest_user_id.to_proto(),
1878            is_host: false,
1879            committer_name: request.committer_name.clone(),
1880            committer_email: request.committer_email.clone(),
1881        }),
1882    };
1883
1884    for collaborator in &collaborators {
1885        session
1886            .peer
1887            .send(
1888                collaborator.peer_id.unwrap().into(),
1889                add_project_collaborator.clone(),
1890            )
1891            .trace_err();
1892    }
1893
1894    // First, we send the metadata associated with each worktree.
1895    response.send(proto::JoinProjectResponse {
1896        project_id: project.id.0 as u64,
1897        worktrees: worktrees.clone(),
1898        replica_id: replica_id.0 as u32,
1899        collaborators: collaborators.clone(),
1900        language_servers: project.language_servers.clone(),
1901        role: project.role.into(),
1902    })?;
1903
1904    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1905        // Stream this worktree's entries.
1906        let message = proto::UpdateWorktree {
1907            project_id: project_id.to_proto(),
1908            worktree_id,
1909            abs_path: worktree.abs_path.clone(),
1910            root_name: worktree.root_name,
1911            updated_entries: worktree.entries,
1912            removed_entries: Default::default(),
1913            scan_id: worktree.scan_id,
1914            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1915            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1916            removed_repositories: Default::default(),
1917        };
1918        for update in proto::split_worktree_update(message) {
1919            session.peer.send(session.connection_id, update.clone())?;
1920        }
1921
1922        // Stream this worktree's diagnostics.
1923        for summary in worktree.diagnostic_summaries {
1924            session.peer.send(
1925                session.connection_id,
1926                proto::UpdateDiagnosticSummary {
1927                    project_id: project_id.to_proto(),
1928                    worktree_id: worktree.id,
1929                    summary: Some(summary),
1930                },
1931            )?;
1932        }
1933
1934        for settings_file in worktree.settings_files {
1935            session.peer.send(
1936                session.connection_id,
1937                proto::UpdateWorktreeSettings {
1938                    project_id: project_id.to_proto(),
1939                    worktree_id: worktree.id,
1940                    path: settings_file.path,
1941                    content: Some(settings_file.content),
1942                    kind: Some(settings_file.kind.to_proto() as i32),
1943                },
1944            )?;
1945        }
1946    }
1947
1948    for repository in mem::take(&mut project.repositories) {
1949        for update in split_repository_update(repository) {
1950            session.peer.send(session.connection_id, update)?;
1951        }
1952    }
1953
1954    for language_server in &project.language_servers {
1955        session.peer.send(
1956            session.connection_id,
1957            proto::UpdateLanguageServer {
1958                project_id: project_id.to_proto(),
1959                language_server_id: language_server.id,
1960                variant: Some(
1961                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1962                        proto::LspDiskBasedDiagnosticsUpdated {},
1963                    ),
1964                ),
1965            },
1966        )?;
1967    }
1968
1969    Ok(())
1970}
1971
1972/// Leave someone elses shared project.
1973async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1974    let sender_id = session.connection_id;
1975    let project_id = ProjectId::from_proto(request.project_id);
1976    let db = session.db().await;
1977
1978    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1979    tracing::info!(
1980        %project_id,
1981        "leave project"
1982    );
1983
1984    project_left(project, &session);
1985    if let Some(room) = room {
1986        room_updated(room, &session.peer);
1987    }
1988
1989    Ok(())
1990}
1991
1992/// Updates other participants with changes to the project
1993async fn update_project(
1994    request: proto::UpdateProject,
1995    response: Response<proto::UpdateProject>,
1996    session: Session,
1997) -> Result<()> {
1998    let project_id = ProjectId::from_proto(request.project_id);
1999    let (room, guest_connection_ids) = &*session
2000        .db()
2001        .await
2002        .update_project(project_id, session.connection_id, &request.worktrees)
2003        .await?;
2004    broadcast(
2005        Some(session.connection_id),
2006        guest_connection_ids.iter().copied(),
2007        |connection_id| {
2008            session
2009                .peer
2010                .forward_send(session.connection_id, connection_id, request.clone())
2011        },
2012    );
2013    if let Some(room) = room {
2014        room_updated(room, &session.peer);
2015    }
2016    response.send(proto::Ack {})?;
2017
2018    Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree
2022async fn update_worktree(
2023    request: proto::UpdateWorktree,
2024    response: Response<proto::UpdateWorktree>,
2025    session: Session,
2026) -> Result<()> {
2027    let guest_connection_ids = session
2028        .db()
2029        .await
2030        .update_worktree(&request, session.connection_id)
2031        .await?;
2032
2033    broadcast(
2034        Some(session.connection_id),
2035        guest_connection_ids.iter().copied(),
2036        |connection_id| {
2037            session
2038                .peer
2039                .forward_send(session.connection_id, connection_id, request.clone())
2040        },
2041    );
2042    response.send(proto::Ack {})?;
2043    Ok(())
2044}
2045
2046async fn update_repository(
2047    request: proto::UpdateRepository,
2048    response: Response<proto::UpdateRepository>,
2049    session: Session,
2050) -> Result<()> {
2051    let guest_connection_ids = session
2052        .db()
2053        .await
2054        .update_repository(&request, session.connection_id)
2055        .await?;
2056
2057    broadcast(
2058        Some(session.connection_id),
2059        guest_connection_ids.iter().copied(),
2060        |connection_id| {
2061            session
2062                .peer
2063                .forward_send(session.connection_id, connection_id, request.clone())
2064        },
2065    );
2066    response.send(proto::Ack {})?;
2067    Ok(())
2068}
2069
2070async fn remove_repository(
2071    request: proto::RemoveRepository,
2072    response: Response<proto::RemoveRepository>,
2073    session: Session,
2074) -> Result<()> {
2075    let guest_connection_ids = session
2076        .db()
2077        .await
2078        .remove_repository(&request, session.connection_id)
2079        .await?;
2080
2081    broadcast(
2082        Some(session.connection_id),
2083        guest_connection_ids.iter().copied(),
2084        |connection_id| {
2085            session
2086                .peer
2087                .forward_send(session.connection_id, connection_id, request.clone())
2088        },
2089    );
2090    response.send(proto::Ack {})?;
2091    Ok(())
2092}
2093
2094/// Updates other participants with changes to the diagnostics
2095async fn update_diagnostic_summary(
2096    message: proto::UpdateDiagnosticSummary,
2097    session: Session,
2098) -> Result<()> {
2099    let guest_connection_ids = session
2100        .db()
2101        .await
2102        .update_diagnostic_summary(&message, session.connection_id)
2103        .await?;
2104
2105    broadcast(
2106        Some(session.connection_id),
2107        guest_connection_ids.iter().copied(),
2108        |connection_id| {
2109            session
2110                .peer
2111                .forward_send(session.connection_id, connection_id, message.clone())
2112        },
2113    );
2114
2115    Ok(())
2116}
2117
2118/// Updates other participants with changes to the worktree settings
2119async fn update_worktree_settings(
2120    message: proto::UpdateWorktreeSettings,
2121    session: Session,
2122) -> Result<()> {
2123    let guest_connection_ids = session
2124        .db()
2125        .await
2126        .update_worktree_settings(&message, session.connection_id)
2127        .await?;
2128
2129    broadcast(
2130        Some(session.connection_id),
2131        guest_connection_ids.iter().copied(),
2132        |connection_id| {
2133            session
2134                .peer
2135                .forward_send(session.connection_id, connection_id, message.clone())
2136        },
2137    );
2138
2139    Ok(())
2140}
2141
2142/// Notify other participants that a language server has started.
2143async fn start_language_server(
2144    request: proto::StartLanguageServer,
2145    session: Session,
2146) -> Result<()> {
2147    let guest_connection_ids = session
2148        .db()
2149        .await
2150        .start_language_server(&request, session.connection_id)
2151        .await?;
2152
2153    broadcast(
2154        Some(session.connection_id),
2155        guest_connection_ids.iter().copied(),
2156        |connection_id| {
2157            session
2158                .peer
2159                .forward_send(session.connection_id, connection_id, request.clone())
2160        },
2161    );
2162    Ok(())
2163}
2164
2165/// Notify other participants that a language server has changed.
2166async fn update_language_server(
2167    request: proto::UpdateLanguageServer,
2168    session: Session,
2169) -> Result<()> {
2170    let project_id = ProjectId::from_proto(request.project_id);
2171    let project_connection_ids = session
2172        .db()
2173        .await
2174        .project_connection_ids(project_id, session.connection_id, true)
2175        .await?;
2176    broadcast(
2177        Some(session.connection_id),
2178        project_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, request.clone())
2183        },
2184    );
2185    Ok(())
2186}
2187
2188/// forward a project request to the host. These requests should be read only
2189/// as guests are allowed to send them.
2190async fn forward_read_only_project_request<T>(
2191    request: T,
2192    response: Response<T>,
2193    session: Session,
2194) -> Result<()>
2195where
2196    T: EntityMessage + RequestMessage,
2197{
2198    let project_id = ProjectId::from_proto(request.remote_entity_id());
2199    let host_connection_id = session
2200        .db()
2201        .await
2202        .host_for_read_only_project_request(project_id, session.connection_id)
2203        .await?;
2204    let payload = session
2205        .peer
2206        .forward_request(session.connection_id, host_connection_id, request)
2207        .await?;
2208    response.send(payload)?;
2209    Ok(())
2210}
2211
2212async fn forward_find_search_candidates_request(
2213    request: proto::FindSearchCandidates,
2214    response: Response<proto::FindSearchCandidates>,
2215    session: Session,
2216) -> Result<()> {
2217    let project_id = ProjectId::from_proto(request.remote_entity_id());
2218    let host_connection_id = session
2219        .db()
2220        .await
2221        .host_for_read_only_project_request(project_id, session.connection_id)
2222        .await?;
2223    let payload = session
2224        .peer
2225        .forward_request(session.connection_id, host_connection_id, request)
2226        .await?;
2227    response.send(payload)?;
2228    Ok(())
2229}
2230
2231/// forward a project request to the host. These requests are disallowed
2232/// for guests.
2233async fn forward_mutating_project_request<T>(
2234    request: T,
2235    response: Response<T>,
2236    session: Session,
2237) -> Result<()>
2238where
2239    T: EntityMessage + RequestMessage,
2240{
2241    let project_id = ProjectId::from_proto(request.remote_entity_id());
2242
2243    let host_connection_id = session
2244        .db()
2245        .await
2246        .host_for_mutating_project_request(project_id, session.connection_id)
2247        .await?;
2248    let payload = session
2249        .peer
2250        .forward_request(session.connection_id, host_connection_id, request)
2251        .await?;
2252    response.send(payload)?;
2253    Ok(())
2254}
2255
2256/// Notify other participants that a new buffer has been created
2257async fn create_buffer_for_peer(
2258    request: proto::CreateBufferForPeer,
2259    session: Session,
2260) -> Result<()> {
2261    session
2262        .db()
2263        .await
2264        .check_user_is_project_host(
2265            ProjectId::from_proto(request.project_id),
2266            session.connection_id,
2267        )
2268        .await?;
2269    let peer_id = request.peer_id.context("invalid peer id")?;
2270    session
2271        .peer
2272        .forward_send(session.connection_id, peer_id.into(), request)?;
2273    Ok(())
2274}
2275
2276/// Notify other participants that a buffer has been updated. This is
2277/// allowed for guests as long as the update is limited to selections.
2278async fn update_buffer(
2279    request: proto::UpdateBuffer,
2280    response: Response<proto::UpdateBuffer>,
2281    session: Session,
2282) -> Result<()> {
2283    let project_id = ProjectId::from_proto(request.project_id);
2284    let mut capability = Capability::ReadOnly;
2285
2286    for op in request.operations.iter() {
2287        match op.variant {
2288            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2289            Some(_) => capability = Capability::ReadWrite,
2290        }
2291    }
2292
2293    let host = {
2294        let guard = session
2295            .db()
2296            .await
2297            .connections_for_buffer_update(project_id, session.connection_id, capability)
2298            .await?;
2299
2300        let (host, guests) = &*guard;
2301
2302        broadcast(
2303            Some(session.connection_id),
2304            guests.clone(),
2305            |connection_id| {
2306                session
2307                    .peer
2308                    .forward_send(session.connection_id, connection_id, request.clone())
2309            },
2310        );
2311
2312        *host
2313    };
2314
2315    if host != session.connection_id {
2316        session
2317            .peer
2318            .forward_request(session.connection_id, host, request.clone())
2319            .await?;
2320    }
2321
2322    response.send(proto::Ack {})?;
2323    Ok(())
2324}
2325
2326async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2327    let project_id = ProjectId::from_proto(message.project_id);
2328
2329    let operation = message.operation.as_ref().context("invalid operation")?;
2330    let capability = match operation.variant.as_ref() {
2331        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2332            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2333                match buffer_op.variant {
2334                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2335                        Capability::ReadOnly
2336                    }
2337                    _ => Capability::ReadWrite,
2338                }
2339            } else {
2340                Capability::ReadWrite
2341            }
2342        }
2343        Some(_) => Capability::ReadWrite,
2344        None => Capability::ReadOnly,
2345    };
2346
2347    let guard = session
2348        .db()
2349        .await
2350        .connections_for_buffer_update(project_id, session.connection_id, capability)
2351        .await?;
2352
2353    let (host, guests) = &*guard;
2354
2355    broadcast(
2356        Some(session.connection_id),
2357        guests.iter().chain([host]).copied(),
2358        |connection_id| {
2359            session
2360                .peer
2361                .forward_send(session.connection_id, connection_id, message.clone())
2362        },
2363    );
2364
2365    Ok(())
2366}
2367
2368/// Notify other participants that a project has been updated.
2369async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2370    request: T,
2371    session: Session,
2372) -> Result<()> {
2373    let project_id = ProjectId::from_proto(request.remote_entity_id());
2374    let project_connection_ids = session
2375        .db()
2376        .await
2377        .project_connection_ids(project_id, session.connection_id, false)
2378        .await?;
2379
2380    broadcast(
2381        Some(session.connection_id),
2382        project_connection_ids.iter().copied(),
2383        |connection_id| {
2384            session
2385                .peer
2386                .forward_send(session.connection_id, connection_id, request.clone())
2387        },
2388    );
2389    Ok(())
2390}
2391
2392/// Start following another user in a call.
2393async fn follow(
2394    request: proto::Follow,
2395    response: Response<proto::Follow>,
2396    session: Session,
2397) -> Result<()> {
2398    let room_id = RoomId::from_proto(request.room_id);
2399    let project_id = request.project_id.map(ProjectId::from_proto);
2400    let leader_id = request.leader_id.context("invalid leader id")?.into();
2401    let follower_id = session.connection_id;
2402
2403    session
2404        .db()
2405        .await
2406        .check_room_participants(room_id, leader_id, session.connection_id)
2407        .await?;
2408
2409    let response_payload = session
2410        .peer
2411        .forward_request(session.connection_id, leader_id, request)
2412        .await?;
2413    response.send(response_payload)?;
2414
2415    if let Some(project_id) = project_id {
2416        let room = session
2417            .db()
2418            .await
2419            .follow(room_id, project_id, leader_id, follower_id)
2420            .await?;
2421        room_updated(&room, &session.peer);
2422    }
2423
2424    Ok(())
2425}
2426
2427/// Stop following another user in a call.
2428async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2429    let room_id = RoomId::from_proto(request.room_id);
2430    let project_id = request.project_id.map(ProjectId::from_proto);
2431    let leader_id = request.leader_id.context("invalid leader id")?.into();
2432    let follower_id = session.connection_id;
2433
2434    session
2435        .db()
2436        .await
2437        .check_room_participants(room_id, leader_id, session.connection_id)
2438        .await?;
2439
2440    session
2441        .peer
2442        .forward_send(session.connection_id, leader_id, request)?;
2443
2444    if let Some(project_id) = project_id {
2445        let room = session
2446            .db()
2447            .await
2448            .unfollow(room_id, project_id, leader_id, follower_id)
2449            .await?;
2450        room_updated(&room, &session.peer);
2451    }
2452
2453    Ok(())
2454}
2455
2456/// Notify everyone following you of your current location.
2457async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2458    let room_id = RoomId::from_proto(request.room_id);
2459    let database = session.db.lock().await;
2460
2461    let connection_ids = if let Some(project_id) = request.project_id {
2462        let project_id = ProjectId::from_proto(project_id);
2463        database
2464            .project_connection_ids(project_id, session.connection_id, true)
2465            .await?
2466    } else {
2467        database
2468            .room_connection_ids(room_id, session.connection_id)
2469            .await?
2470    };
2471
2472    // For now, don't send view update messages back to that view's current leader.
2473    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2474        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2475        _ => None,
2476    });
2477
2478    for connection_id in connection_ids.iter().cloned() {
2479        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2480            session
2481                .peer
2482                .forward_send(session.connection_id, connection_id, request.clone())?;
2483        }
2484    }
2485    Ok(())
2486}
2487
2488/// Get public data about users.
2489async fn get_users(
2490    request: proto::GetUsers,
2491    response: Response<proto::GetUsers>,
2492    session: Session,
2493) -> Result<()> {
2494    let user_ids = request
2495        .user_ids
2496        .into_iter()
2497        .map(UserId::from_proto)
2498        .collect();
2499    let users = session
2500        .db()
2501        .await
2502        .get_users_by_ids(user_ids)
2503        .await?
2504        .into_iter()
2505        .map(|user| proto::User {
2506            id: user.id.to_proto(),
2507            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2508            github_login: user.github_login,
2509            name: user.name,
2510        })
2511        .collect();
2512    response.send(proto::UsersResponse { users })?;
2513    Ok(())
2514}
2515
2516/// Search for users (to invite) buy Github login
2517async fn fuzzy_search_users(
2518    request: proto::FuzzySearchUsers,
2519    response: Response<proto::FuzzySearchUsers>,
2520    session: Session,
2521) -> Result<()> {
2522    let query = request.query;
2523    let users = match query.len() {
2524        0 => vec![],
2525        1 | 2 => session
2526            .db()
2527            .await
2528            .get_user_by_github_login(&query)
2529            .await?
2530            .into_iter()
2531            .collect(),
2532        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2533    };
2534    let users = users
2535        .into_iter()
2536        .filter(|user| user.id != session.user_id())
2537        .map(|user| proto::User {
2538            id: user.id.to_proto(),
2539            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2540            github_login: user.github_login,
2541            name: user.name,
2542        })
2543        .collect();
2544    response.send(proto::UsersResponse { users })?;
2545    Ok(())
2546}
2547
2548/// Send a contact request to another user.
2549async fn request_contact(
2550    request: proto::RequestContact,
2551    response: Response<proto::RequestContact>,
2552    session: Session,
2553) -> Result<()> {
2554    let requester_id = session.user_id();
2555    let responder_id = UserId::from_proto(request.responder_id);
2556    if requester_id == responder_id {
2557        return Err(anyhow!("cannot add yourself as a contact"))?;
2558    }
2559
2560    let notifications = session
2561        .db()
2562        .await
2563        .send_contact_request(requester_id, responder_id)
2564        .await?;
2565
2566    // Update outgoing contact requests of requester
2567    let mut update = proto::UpdateContacts::default();
2568    update.outgoing_requests.push(responder_id.to_proto());
2569    for connection_id in session
2570        .connection_pool()
2571        .await
2572        .user_connection_ids(requester_id)
2573    {
2574        session.peer.send(connection_id, update.clone())?;
2575    }
2576
2577    // Update incoming contact requests of responder
2578    let mut update = proto::UpdateContacts::default();
2579    update
2580        .incoming_requests
2581        .push(proto::IncomingContactRequest {
2582            requester_id: requester_id.to_proto(),
2583        });
2584    let connection_pool = session.connection_pool().await;
2585    for connection_id in connection_pool.user_connection_ids(responder_id) {
2586        session.peer.send(connection_id, update.clone())?;
2587    }
2588
2589    send_notifications(&connection_pool, &session.peer, notifications);
2590
2591    response.send(proto::Ack {})?;
2592    Ok(())
2593}
2594
2595/// Accept or decline a contact request
2596async fn respond_to_contact_request(
2597    request: proto::RespondToContactRequest,
2598    response: Response<proto::RespondToContactRequest>,
2599    session: Session,
2600) -> Result<()> {
2601    let responder_id = session.user_id();
2602    let requester_id = UserId::from_proto(request.requester_id);
2603    let db = session.db().await;
2604    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2605        db.dismiss_contact_notification(responder_id, requester_id)
2606            .await?;
2607    } else {
2608        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2609
2610        let notifications = db
2611            .respond_to_contact_request(responder_id, requester_id, accept)
2612            .await?;
2613        let requester_busy = db.is_user_busy(requester_id).await?;
2614        let responder_busy = db.is_user_busy(responder_id).await?;
2615
2616        let pool = session.connection_pool().await;
2617        // Update responder with new contact
2618        let mut update = proto::UpdateContacts::default();
2619        if accept {
2620            update
2621                .contacts
2622                .push(contact_for_user(requester_id, requester_busy, &pool));
2623        }
2624        update
2625            .remove_incoming_requests
2626            .push(requester_id.to_proto());
2627        for connection_id in pool.user_connection_ids(responder_id) {
2628            session.peer.send(connection_id, update.clone())?;
2629        }
2630
2631        // Update requester with new contact
2632        let mut update = proto::UpdateContacts::default();
2633        if accept {
2634            update
2635                .contacts
2636                .push(contact_for_user(responder_id, responder_busy, &pool));
2637        }
2638        update
2639            .remove_outgoing_requests
2640            .push(responder_id.to_proto());
2641
2642        for connection_id in pool.user_connection_ids(requester_id) {
2643            session.peer.send(connection_id, update.clone())?;
2644        }
2645
2646        send_notifications(&pool, &session.peer, notifications);
2647    }
2648
2649    response.send(proto::Ack {})?;
2650    Ok(())
2651}
2652
2653/// Remove a contact.
2654async fn remove_contact(
2655    request: proto::RemoveContact,
2656    response: Response<proto::RemoveContact>,
2657    session: Session,
2658) -> Result<()> {
2659    let requester_id = session.user_id();
2660    let responder_id = UserId::from_proto(request.user_id);
2661    let db = session.db().await;
2662    let (contact_accepted, deleted_notification_id) =
2663        db.remove_contact(requester_id, responder_id).await?;
2664
2665    let pool = session.connection_pool().await;
2666    // Update outgoing contact requests of requester
2667    let mut update = proto::UpdateContacts::default();
2668    if contact_accepted {
2669        update.remove_contacts.push(responder_id.to_proto());
2670    } else {
2671        update
2672            .remove_outgoing_requests
2673            .push(responder_id.to_proto());
2674    }
2675    for connection_id in pool.user_connection_ids(requester_id) {
2676        session.peer.send(connection_id, update.clone())?;
2677    }
2678
2679    // Update incoming contact requests of responder
2680    let mut update = proto::UpdateContacts::default();
2681    if contact_accepted {
2682        update.remove_contacts.push(requester_id.to_proto());
2683    } else {
2684        update
2685            .remove_incoming_requests
2686            .push(requester_id.to_proto());
2687    }
2688    for connection_id in pool.user_connection_ids(responder_id) {
2689        session.peer.send(connection_id, update.clone())?;
2690        if let Some(notification_id) = deleted_notification_id {
2691            session.peer.send(
2692                connection_id,
2693                proto::DeleteNotification {
2694                    notification_id: notification_id.to_proto(),
2695                },
2696            )?;
2697        }
2698    }
2699
2700    response.send(proto::Ack {})?;
2701    Ok(())
2702}
2703
2704fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2705    version.0.minor() < 139
2706}
2707
2708async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2709    if is_staff {
2710        return Ok(proto::Plan::ZedPro);
2711    }
2712
2713    let subscription = db.get_active_billing_subscription(user_id).await?;
2714    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2715
2716    let plan = if let Some(subscription_kind) = subscription_kind {
2717        match subscription_kind {
2718            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2719            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2720            SubscriptionKind::ZedFree => proto::Plan::Free,
2721        }
2722    } else {
2723        proto::Plan::Free
2724    };
2725
2726    Ok(plan)
2727}
2728
/// Builds the `UpdateUserPlan` message describing a user's plan, trial, billing,
/// and usage state.
///
/// - `user`: the user the message describes.
/// - `is_staff`: staff are reported as `ZedPro` (via `current_plan`) and always
///   have usage-based billing reported as enabled.
/// - `db`: the main database. It is used for feature flags, the billing customer,
///   billing preferences, and the active subscription.
/// - `llm_db`: when `None`, both `subscription_period` and `usage` are left unset.
///
/// Returns an error if any of the database lookups fail.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // The subscription period and its usage are only computed when an LLM database
    // is available.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): for non-staff users `current_plan` above already queried the
        // active subscription; this repeats that query.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only looked up for the current period; with no period there is
        // no usage to report.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Any plan other than `ZedPro` is subject to the minimum account age.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Reported as a Unix timestamp in seconds, interpreting the stored value as UTC.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Translate the protocol plan into the LLM client's plan so its request
            // limits can be looked up.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Users on the extended-trial feature flag get a limited model request
            // allowance of 1,000 while on `ZedProTrial`, in place of the plan's default.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below truncate silently if a source
            // value does not fit in a u32.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                // Edit prediction limits come straight from the plan, with no
                // feature-flag override.
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2837
2838async fn update_user_plan(session: &Session) -> Result<()> {
2839    let db = session.db().await;
2840
2841    let update_user_plan = make_update_user_plan_message(
2842        session.principal.user(),
2843        session.is_staff(),
2844        &db.0,
2845        session.app_state.llm_db.clone(),
2846    )
2847    .await?;
2848
2849    session
2850        .peer
2851        .send(session.connection_id, update_user_plan)
2852        .trace_err();
2853
2854    Ok(())
2855}
2856
2857async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2858    subscribe_user_to_channels(session.user_id(), &session).await?;
2859    Ok(())
2860}
2861
2862async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2863    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2864    let mut pool = session.connection_pool().await;
2865    for membership in &channels_for_user.channel_memberships {
2866        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2867    }
2868    session.peer.send(
2869        session.connection_id,
2870        build_update_user_channels(&channels_for_user),
2871    )?;
2872    session.peer.send(
2873        session.connection_id,
2874        build_channels_update(channels_for_user),
2875    )?;
2876    Ok(())
2877}
2878
2879/// Creates a new channel.
2880async fn create_channel(
2881    request: proto::CreateChannel,
2882    response: Response<proto::CreateChannel>,
2883    session: Session,
2884) -> Result<()> {
2885    let db = session.db().await;
2886
2887    let parent_id = request.parent_id.map(ChannelId::from_proto);
2888    let (channel, membership) = db
2889        .create_channel(&request.name, parent_id, session.user_id())
2890        .await?;
2891
2892    let root_id = channel.root_id();
2893    let channel = Channel::from_model(channel);
2894
2895    response.send(proto::CreateChannelResponse {
2896        channel: Some(channel.to_proto()),
2897        parent_id: request.parent_id,
2898    })?;
2899
2900    let mut connection_pool = session.connection_pool().await;
2901    if let Some(membership) = membership {
2902        connection_pool.subscribe_to_channel(
2903            membership.user_id,
2904            membership.channel_id,
2905            membership.role,
2906        );
2907        let update = proto::UpdateUserChannels {
2908            channel_memberships: vec![proto::ChannelMembership {
2909                channel_id: membership.channel_id.to_proto(),
2910                role: membership.role.into(),
2911            }],
2912            ..Default::default()
2913        };
2914        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2915            session.peer.send(connection_id, update.clone())?;
2916        }
2917    }
2918
2919    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2920        if !role.can_see_channel(channel.visibility) {
2921            continue;
2922        }
2923
2924        let update = proto::UpdateChannels {
2925            channels: vec![channel.to_proto()],
2926            ..Default::default()
2927        };
2928        session.peer.send(connection_id, update.clone())?;
2929    }
2930
2931    Ok(())
2932}
2933
2934/// Delete a channel
2935async fn delete_channel(
2936    request: proto::DeleteChannel,
2937    response: Response<proto::DeleteChannel>,
2938    session: Session,
2939) -> Result<()> {
2940    let db = session.db().await;
2941
2942    let channel_id = request.channel_id;
2943    let (root_channel, removed_channels) = db
2944        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2945        .await?;
2946    response.send(proto::Ack {})?;
2947
2948    // Notify members of removed channels
2949    let mut update = proto::UpdateChannels::default();
2950    update
2951        .delete_channels
2952        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2953
2954    let connection_pool = session.connection_pool().await;
2955    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2956        session.peer.send(connection_id, update.clone())?;
2957    }
2958
2959    Ok(())
2960}
2961
2962/// Invite someone to join a channel.
2963async fn invite_channel_member(
2964    request: proto::InviteChannelMember,
2965    response: Response<proto::InviteChannelMember>,
2966    session: Session,
2967) -> Result<()> {
2968    let db = session.db().await;
2969    let channel_id = ChannelId::from_proto(request.channel_id);
2970    let invitee_id = UserId::from_proto(request.user_id);
2971    let InviteMemberResult {
2972        channel,
2973        notifications,
2974    } = db
2975        .invite_channel_member(
2976            channel_id,
2977            invitee_id,
2978            session.user_id(),
2979            request.role().into(),
2980        )
2981        .await?;
2982
2983    let update = proto::UpdateChannels {
2984        channel_invitations: vec![channel.to_proto()],
2985        ..Default::default()
2986    };
2987
2988    let connection_pool = session.connection_pool().await;
2989    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2990        session.peer.send(connection_id, update.clone())?;
2991    }
2992
2993    send_notifications(&connection_pool, &session.peer, notifications);
2994
2995    response.send(proto::Ack {})?;
2996    Ok(())
2997}
2998
2999/// remove someone from a channel
3000async fn remove_channel_member(
3001    request: proto::RemoveChannelMember,
3002    response: Response<proto::RemoveChannelMember>,
3003    session: Session,
3004) -> Result<()> {
3005    let db = session.db().await;
3006    let channel_id = ChannelId::from_proto(request.channel_id);
3007    let member_id = UserId::from_proto(request.user_id);
3008
3009    let RemoveChannelMemberResult {
3010        membership_update,
3011        notification_id,
3012    } = db
3013        .remove_channel_member(channel_id, member_id, session.user_id())
3014        .await?;
3015
3016    let mut connection_pool = session.connection_pool().await;
3017    notify_membership_updated(
3018        &mut connection_pool,
3019        membership_update,
3020        member_id,
3021        &session.peer,
3022    );
3023    for connection_id in connection_pool.user_connection_ids(member_id) {
3024        if let Some(notification_id) = notification_id {
3025            session
3026                .peer
3027                .send(
3028                    connection_id,
3029                    proto::DeleteNotification {
3030                        notification_id: notification_id.to_proto(),
3031                    },
3032                )
3033                .trace_err();
3034        }
3035    }
3036
3037    response.send(proto::Ack {})?;
3038    Ok(())
3039}
3040
3041/// Toggle the channel between public and private.
3042/// Care is taken to maintain the invariant that public channels only descend from public channels,
3043/// (though members-only channels can appear at any point in the hierarchy).
3044async fn set_channel_visibility(
3045    request: proto::SetChannelVisibility,
3046    response: Response<proto::SetChannelVisibility>,
3047    session: Session,
3048) -> Result<()> {
3049    let db = session.db().await;
3050    let channel_id = ChannelId::from_proto(request.channel_id);
3051    let visibility = request.visibility().into();
3052
3053    let channel_model = db
3054        .set_channel_visibility(channel_id, visibility, session.user_id())
3055        .await?;
3056    let root_id = channel_model.root_id();
3057    let channel = Channel::from_model(channel_model);
3058
3059    let mut connection_pool = session.connection_pool().await;
3060    for (user_id, role) in connection_pool
3061        .channel_user_ids(root_id)
3062        .collect::<Vec<_>>()
3063        .into_iter()
3064    {
3065        let update = if role.can_see_channel(channel.visibility) {
3066            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3067            proto::UpdateChannels {
3068                channels: vec![channel.to_proto()],
3069                ..Default::default()
3070            }
3071        } else {
3072            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3073            proto::UpdateChannels {
3074                delete_channels: vec![channel.id.to_proto()],
3075                ..Default::default()
3076            }
3077        };
3078
3079        for connection_id in connection_pool.user_connection_ids(user_id) {
3080            session.peer.send(connection_id, update.clone())?;
3081        }
3082    }
3083
3084    response.send(proto::Ack {})?;
3085    Ok(())
3086}
3087
3088/// Alter the role for a user in the channel.
3089async fn set_channel_member_role(
3090    request: proto::SetChannelMemberRole,
3091    response: Response<proto::SetChannelMemberRole>,
3092    session: Session,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096    let member_id = UserId::from_proto(request.user_id);
3097    let result = db
3098        .set_channel_member_role(
3099            channel_id,
3100            session.user_id(),
3101            member_id,
3102            request.role().into(),
3103        )
3104        .await?;
3105
3106    match result {
3107        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3108            let mut connection_pool = session.connection_pool().await;
3109            notify_membership_updated(
3110                &mut connection_pool,
3111                membership_update,
3112                member_id,
3113                &session.peer,
3114            )
3115        }
3116        db::SetMemberRoleResult::InviteUpdated(channel) => {
3117            let update = proto::UpdateChannels {
3118                channel_invitations: vec![channel.to_proto()],
3119                ..Default::default()
3120            };
3121
3122            for connection_id in session
3123                .connection_pool()
3124                .await
3125                .user_connection_ids(member_id)
3126            {
3127                session.peer.send(connection_id, update.clone())?;
3128            }
3129        }
3130    }
3131
3132    response.send(proto::Ack {})?;
3133    Ok(())
3134}
3135
3136/// Change the name of a channel
3137async fn rename_channel(
3138    request: proto::RenameChannel,
3139    response: Response<proto::RenameChannel>,
3140    session: Session,
3141) -> Result<()> {
3142    let db = session.db().await;
3143    let channel_id = ChannelId::from_proto(request.channel_id);
3144    let channel_model = db
3145        .rename_channel(channel_id, session.user_id(), &request.name)
3146        .await?;
3147    let root_id = channel_model.root_id();
3148    let channel = Channel::from_model(channel_model);
3149
3150    response.send(proto::RenameChannelResponse {
3151        channel: Some(channel.to_proto()),
3152    })?;
3153
3154    let connection_pool = session.connection_pool().await;
3155    let update = proto::UpdateChannels {
3156        channels: vec![channel.to_proto()],
3157        ..Default::default()
3158    };
3159    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3160        if role.can_see_channel(channel.visibility) {
3161            session.peer.send(connection_id, update.clone())?;
3162        }
3163    }
3164
3165    Ok(())
3166}
3167
3168/// Move a channel to a new parent.
3169async fn move_channel(
3170    request: proto::MoveChannel,
3171    response: Response<proto::MoveChannel>,
3172    session: Session,
3173) -> Result<()> {
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175    let to = ChannelId::from_proto(request.to);
3176
3177    let (root_id, channels) = session
3178        .db()
3179        .await
3180        .move_channel(channel_id, to, session.user_id())
3181        .await?;
3182
3183    let connection_pool = session.connection_pool().await;
3184    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3185        let channels = channels
3186            .iter()
3187            .filter_map(|channel| {
3188                if role.can_see_channel(channel.visibility) {
3189                    Some(channel.to_proto())
3190                } else {
3191                    None
3192                }
3193            })
3194            .collect::<Vec<_>>();
3195        if channels.is_empty() {
3196            continue;
3197        }
3198
3199        let update = proto::UpdateChannels {
3200            channels,
3201            ..Default::default()
3202        };
3203
3204        session.peer.send(connection_id, update.clone())?;
3205    }
3206
3207    response.send(Ack {})?;
3208    Ok(())
3209}
3210
3211/// Get the list of channel members
3212async fn get_channel_members(
3213    request: proto::GetChannelMembers,
3214    response: Response<proto::GetChannelMembers>,
3215    session: Session,
3216) -> Result<()> {
3217    let db = session.db().await;
3218    let channel_id = ChannelId::from_proto(request.channel_id);
3219    let limit = if request.limit == 0 {
3220        u16::MAX as u64
3221    } else {
3222        request.limit
3223    };
3224    let (members, users) = db
3225        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3226        .await?;
3227    response.send(proto::GetChannelMembersResponse { members, users })?;
3228    Ok(())
3229}
3230
3231/// Accept or decline a channel invitation.
3232async fn respond_to_channel_invite(
3233    request: proto::RespondToChannelInvite,
3234    response: Response<proto::RespondToChannelInvite>,
3235    session: Session,
3236) -> Result<()> {
3237    let db = session.db().await;
3238    let channel_id = ChannelId::from_proto(request.channel_id);
3239    let RespondToChannelInvite {
3240        membership_update,
3241        notifications,
3242    } = db
3243        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3244        .await?;
3245
3246    let mut connection_pool = session.connection_pool().await;
3247    if let Some(membership_update) = membership_update {
3248        notify_membership_updated(
3249            &mut connection_pool,
3250            membership_update,
3251            session.user_id(),
3252            &session.peer,
3253        );
3254    } else {
3255        let update = proto::UpdateChannels {
3256            remove_channel_invitations: vec![channel_id.to_proto()],
3257            ..Default::default()
3258        };
3259
3260        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3261            session.peer.send(connection_id, update.clone())?;
3262        }
3263    };
3264
3265    send_notifications(&connection_pool, &session.peer, notifications);
3266
3267    response.send(proto::Ack {})?;
3268
3269    Ok(())
3270}
3271
3272/// Join the channels' room
3273async fn join_channel(
3274    request: proto::JoinChannel,
3275    response: Response<proto::JoinChannel>,
3276    session: Session,
3277) -> Result<()> {
3278    let channel_id = ChannelId::from_proto(request.channel_id);
3279    join_channel_internal(channel_id, Box::new(response), session).await
3280}
3281
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` response handles, so that
/// `join_channel_internal` (which takes `Box<impl JoinChannelInternalResponse>`)
/// can reply to either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` for the `JoinChannel` handle.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::send` for the `JoinRoom` handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3295
/// Joins the current user to the room belonging to `channel_id` and notifies interested parties.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is removed
///    from its room via `leave_room_for_session` first.
/// 2. The database join runs. It yields the joined room, an optional membership update, and the
///    user's role in the channel.
/// 3. If a LiveKit client is configured, connection info is built:
///    - guests get a guest token and `can_publish: false`;
///    - everyone else gets a room token and `can_publish: true`.
///
///    A token error goes through `trace_err` and yields no connection info instead of failing
///    the join.
/// 4. The `JoinRoomResponse` is sent. Any membership update is pushed to the user's connections,
///    and `room_updated` is called for the room.
/// 5. `channel_updated` is called for the channel and the user's contacts are updated.
///
/// # Errors
/// Returns an error in any of these cases:
/// - stale-connection cleanup fails;
/// - the database join fails;
/// - sending the response fails;
/// - updating contacts fails;
/// - the joined room carries no channel.
///
/// The missing-channel check runs after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB guard while leaving the stale room (the helper presumably needs the
            // DB itself — TODO confirm), then reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may listen but not publish; token failures degrade to "no LiveKit info".
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // The DB guard and this connection-pool guard are dropped when the block ends.
        joined_room
    };

    // A fresh connection-pool guard is taken here because the one above went out of scope.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3389
3390/// Start editing the channel notes
3391async fn join_channel_buffer(
3392    request: proto::JoinChannelBuffer,
3393    response: Response<proto::JoinChannelBuffer>,
3394    session: Session,
3395) -> Result<()> {
3396    let db = session.db().await;
3397    let channel_id = ChannelId::from_proto(request.channel_id);
3398
3399    let open_response = db
3400        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3401        .await?;
3402
3403    let collaborators = open_response.collaborators.clone();
3404    response.send(open_response)?;
3405
3406    let update = UpdateChannelBufferCollaborators {
3407        channel_id: channel_id.to_proto(),
3408        collaborators: collaborators.clone(),
3409    };
3410    channel_buffer_updated(
3411        session.connection_id,
3412        collaborators
3413            .iter()
3414            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3415        &update,
3416        &session.peer,
3417    );
3418
3419    Ok(())
3420}
3421
3422/// Edit the channel notes
3423async fn update_channel_buffer(
3424    request: proto::UpdateChannelBuffer,
3425    session: Session,
3426) -> Result<()> {
3427    let db = session.db().await;
3428    let channel_id = ChannelId::from_proto(request.channel_id);
3429
3430    let (collaborators, epoch, version) = db
3431        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3432        .await?;
3433
3434    channel_buffer_updated(
3435        session.connection_id,
3436        collaborators.clone(),
3437        &proto::UpdateChannelBuffer {
3438            channel_id: channel_id.to_proto(),
3439            operations: request.operations,
3440        },
3441        &session.peer,
3442    );
3443
3444    let pool = &*session.connection_pool().await;
3445
3446    let non_collaborators =
3447        pool.channel_connection_ids(channel_id)
3448            .filter_map(|(connection_id, _)| {
3449                if collaborators.contains(&connection_id) {
3450                    None
3451                } else {
3452                    Some(connection_id)
3453                }
3454            });
3455
3456    broadcast(None, non_collaborators, |peer_id| {
3457        session.peer.send(
3458            peer_id,
3459            proto::UpdateChannels {
3460                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3461                    channel_id: channel_id.to_proto(),
3462                    epoch: epoch as u64,
3463                    version: version.clone(),
3464                }],
3465                ..Default::default()
3466            },
3467        )
3468    });
3469
3470    Ok(())
3471}
3472
3473/// Rejoin the channel notes after a connection blip
3474async fn rejoin_channel_buffers(
3475    request: proto::RejoinChannelBuffers,
3476    response: Response<proto::RejoinChannelBuffers>,
3477    session: Session,
3478) -> Result<()> {
3479    let db = session.db().await;
3480    let buffers = db
3481        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3482        .await?;
3483
3484    for rejoined_buffer in &buffers {
3485        let collaborators_to_notify = rejoined_buffer
3486            .buffer
3487            .collaborators
3488            .iter()
3489            .filter_map(|c| Some(c.peer_id?.into()));
3490        channel_buffer_updated(
3491            session.connection_id,
3492            collaborators_to_notify,
3493            &proto::UpdateChannelBufferCollaborators {
3494                channel_id: rejoined_buffer.buffer.channel_id,
3495                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3496            },
3497            &session.peer,
3498        );
3499    }
3500
3501    response.send(proto::RejoinChannelBuffersResponse {
3502        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3503    })?;
3504
3505    Ok(())
3506}
3507
3508/// Stop editing the channel notes
3509async fn leave_channel_buffer(
3510    request: proto::LeaveChannelBuffer,
3511    response: Response<proto::LeaveChannelBuffer>,
3512    session: Session,
3513) -> Result<()> {
3514    let db = session.db().await;
3515    let channel_id = ChannelId::from_proto(request.channel_id);
3516
3517    let left_buffer = db
3518        .leave_channel_buffer(channel_id, session.connection_id)
3519        .await?;
3520
3521    response.send(Ack {})?;
3522
3523    channel_buffer_updated(
3524        session.connection_id,
3525        left_buffer.connections,
3526        &proto::UpdateChannelBufferCollaborators {
3527            channel_id: channel_id.to_proto(),
3528            collaborators: left_buffer.collaborators,
3529        },
3530        &session.peer,
3531    );
3532
3533    Ok(())
3534}
3535
3536fn channel_buffer_updated<T: EnvelopedMessage>(
3537    sender_id: ConnectionId,
3538    collaborators: impl IntoIterator<Item = ConnectionId>,
3539    message: &T,
3540    peer: &Peer,
3541) {
3542    broadcast(Some(sender_id), collaborators, |peer_id| {
3543        peer.send(peer_id, message.clone())
3544    });
3545}
3546
3547fn send_notifications(
3548    connection_pool: &ConnectionPool,
3549    peer: &Peer,
3550    notifications: db::NotificationBatch,
3551) {
3552    for (user_id, notification) in notifications {
3553        for connection_id in connection_pool.user_connection_ids(user_id) {
3554            if let Err(error) = peer.send(
3555                connection_id,
3556                proto::AddNotification {
3557                    notification: Some(notification.clone()),
3558                },
3559            ) {
3560                tracing::error!(
3561                    "failed to send notification to {:?} {}",
3562                    connection_id,
3563                    error
3564                );
3565            }
3566        }
3567    }
3568}
3569
3570/// Send a message to the channel
3571async fn send_channel_message(
3572    request: proto::SendChannelMessage,
3573    response: Response<proto::SendChannelMessage>,
3574    session: Session,
3575) -> Result<()> {
3576    // Validate the message body.
3577    let body = request.body.trim().to_string();
3578    if body.len() > MAX_MESSAGE_LEN {
3579        return Err(anyhow!("message is too long"))?;
3580    }
3581    if body.is_empty() {
3582        return Err(anyhow!("message can't be blank"))?;
3583    }
3584
3585    // TODO: adjust mentions if body is trimmed
3586
3587    let timestamp = OffsetDateTime::now_utc();
3588    let nonce = request.nonce.context("nonce can't be blank")?;
3589
3590    let channel_id = ChannelId::from_proto(request.channel_id);
3591    let CreatedChannelMessage {
3592        message_id,
3593        participant_connection_ids,
3594        notifications,
3595    } = session
3596        .db()
3597        .await
3598        .create_channel_message(
3599            channel_id,
3600            session.user_id(),
3601            &body,
3602            &request.mentions,
3603            timestamp,
3604            nonce.clone().into(),
3605            request.reply_to_message_id.map(MessageId::from_proto),
3606        )
3607        .await?;
3608
3609    let message = proto::ChannelMessage {
3610        sender_id: session.user_id().to_proto(),
3611        id: message_id.to_proto(),
3612        body,
3613        mentions: request.mentions,
3614        timestamp: timestamp.unix_timestamp() as u64,
3615        nonce: Some(nonce),
3616        reply_to_message_id: request.reply_to_message_id,
3617        edited_at: None,
3618    };
3619    broadcast(
3620        Some(session.connection_id),
3621        participant_connection_ids.clone(),
3622        |connection| {
3623            session.peer.send(
3624                connection,
3625                proto::ChannelMessageSent {
3626                    channel_id: channel_id.to_proto(),
3627                    message: Some(message.clone()),
3628                },
3629            )
3630        },
3631    );
3632    response.send(proto::SendChannelMessageResponse {
3633        message: Some(message),
3634    })?;
3635
3636    let pool = &*session.connection_pool().await;
3637    let non_participants =
3638        pool.channel_connection_ids(channel_id)
3639            .filter_map(|(connection_id, _)| {
3640                if participant_connection_ids.contains(&connection_id) {
3641                    None
3642                } else {
3643                    Some(connection_id)
3644                }
3645            });
3646    broadcast(None, non_participants, |peer_id| {
3647        session.peer.send(
3648            peer_id,
3649            proto::UpdateChannels {
3650                latest_channel_message_ids: vec![proto::ChannelMessageId {
3651                    channel_id: channel_id.to_proto(),
3652                    message_id: message_id.to_proto(),
3653                }],
3654                ..Default::default()
3655            },
3656        )
3657    });
3658    send_notifications(pool, &session.peer, notifications);
3659
3660    Ok(())
3661}
3662
3663/// Delete a channel message
3664async fn remove_channel_message(
3665    request: proto::RemoveChannelMessage,
3666    response: Response<proto::RemoveChannelMessage>,
3667    session: Session,
3668) -> Result<()> {
3669    let channel_id = ChannelId::from_proto(request.channel_id);
3670    let message_id = MessageId::from_proto(request.message_id);
3671    let (connection_ids, existing_notification_ids) = session
3672        .db()
3673        .await
3674        .remove_channel_message(channel_id, message_id, session.user_id())
3675        .await?;
3676
3677    broadcast(
3678        Some(session.connection_id),
3679        connection_ids,
3680        move |connection| {
3681            session.peer.send(connection, request.clone())?;
3682
3683            for notification_id in &existing_notification_ids {
3684                session.peer.send(
3685                    connection,
3686                    proto::DeleteNotification {
3687                        notification_id: (*notification_id).to_proto(),
3688                    },
3689                )?;
3690            }
3691
3692            Ok(())
3693        },
3694    );
3695    response.send(proto::Ack {})?;
3696    Ok(())
3697}
3698
3699async fn update_channel_message(
3700    request: proto::UpdateChannelMessage,
3701    response: Response<proto::UpdateChannelMessage>,
3702    session: Session,
3703) -> Result<()> {
3704    let channel_id = ChannelId::from_proto(request.channel_id);
3705    let message_id = MessageId::from_proto(request.message_id);
3706    let updated_at = OffsetDateTime::now_utc();
3707    let UpdatedChannelMessage {
3708        message_id,
3709        participant_connection_ids,
3710        notifications,
3711        reply_to_message_id,
3712        timestamp,
3713        deleted_mention_notification_ids,
3714        updated_mention_notifications,
3715    } = session
3716        .db()
3717        .await
3718        .update_channel_message(
3719            channel_id,
3720            message_id,
3721            session.user_id(),
3722            request.body.as_str(),
3723            &request.mentions,
3724            updated_at,
3725        )
3726        .await?;
3727
3728    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3729
3730    let message = proto::ChannelMessage {
3731        sender_id: session.user_id().to_proto(),
3732        id: message_id.to_proto(),
3733        body: request.body.clone(),
3734        mentions: request.mentions.clone(),
3735        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3736        nonce: Some(nonce),
3737        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3738        edited_at: Some(updated_at.unix_timestamp() as u64),
3739    };
3740
3741    response.send(proto::Ack {})?;
3742
3743    let pool = &*session.connection_pool().await;
3744    broadcast(
3745        Some(session.connection_id),
3746        participant_connection_ids,
3747        |connection| {
3748            session.peer.send(
3749                connection,
3750                proto::ChannelMessageUpdate {
3751                    channel_id: channel_id.to_proto(),
3752                    message: Some(message.clone()),
3753                },
3754            )?;
3755
3756            for notification_id in &deleted_mention_notification_ids {
3757                session.peer.send(
3758                    connection,
3759                    proto::DeleteNotification {
3760                        notification_id: (*notification_id).to_proto(),
3761                    },
3762                )?;
3763            }
3764
3765            for notification in &updated_mention_notifications {
3766                session.peer.send(
3767                    connection,
3768                    proto::UpdateNotification {
3769                        notification: Some(notification.clone()),
3770                    },
3771                )?;
3772            }
3773
3774            Ok(())
3775        },
3776    );
3777
3778    send_notifications(pool, &session.peer, notifications);
3779
3780    Ok(())
3781}
3782
3783/// Mark a channel message as read
3784async fn acknowledge_channel_message(
3785    request: proto::AckChannelMessage,
3786    session: Session,
3787) -> Result<()> {
3788    let channel_id = ChannelId::from_proto(request.channel_id);
3789    let message_id = MessageId::from_proto(request.message_id);
3790    let notifications = session
3791        .db()
3792        .await
3793        .observe_channel_message(channel_id, session.user_id(), message_id)
3794        .await?;
3795    send_notifications(
3796        &*session.connection_pool().await,
3797        &session.peer,
3798        notifications,
3799    );
3800    Ok(())
3801}
3802
3803/// Mark a buffer version as synced
3804async fn acknowledge_buffer_version(
3805    request: proto::AckBufferOperation,
3806    session: Session,
3807) -> Result<()> {
3808    let buffer_id = BufferId::from_proto(request.buffer_id);
3809    session
3810        .db()
3811        .await
3812        .observe_buffer_version(
3813            buffer_id,
3814            session.user_id(),
3815            request.epoch as i32,
3816            &request.version,
3817        )
3818        .await?;
3819    Ok(())
3820}
3821
3822/// Get a Supermaven API key for the user
3823async fn get_supermaven_api_key(
3824    _request: proto::GetSupermavenApiKey,
3825    response: Response<proto::GetSupermavenApiKey>,
3826    session: Session,
3827) -> Result<()> {
3828    let user_id: String = session.user_id().to_string();
3829    if !session.is_staff() {
3830        return Err(anyhow!("supermaven not enabled for this account"))?;
3831    }
3832
3833    let email = session.email().context("user must have an email")?;
3834
3835    let supermaven_admin_api = session
3836        .supermaven_client
3837        .as_ref()
3838        .context("supermaven not configured")?;
3839
3840    let result = supermaven_admin_api
3841        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3842        .await?;
3843
3844    response.send(proto::GetSupermavenApiKeyResponse {
3845        api_key: result.api_key,
3846    })?;
3847
3848    Ok(())
3849}
3850
3851/// Start receiving chat updates for a channel
3852async fn join_channel_chat(
3853    request: proto::JoinChannelChat,
3854    response: Response<proto::JoinChannelChat>,
3855    session: Session,
3856) -> Result<()> {
3857    let channel_id = ChannelId::from_proto(request.channel_id);
3858
3859    let db = session.db().await;
3860    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3861        .await?;
3862    let messages = db
3863        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3864        .await?;
3865    response.send(proto::JoinChannelChatResponse {
3866        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3867        messages,
3868    })?;
3869    Ok(())
3870}
3871
3872/// Stop receiving chat updates for a channel
3873async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3874    let channel_id = ChannelId::from_proto(request.channel_id);
3875    session
3876        .db()
3877        .await
3878        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3879        .await?;
3880    Ok(())
3881}
3882
3883/// Retrieve the chat history for a channel
3884async fn get_channel_messages(
3885    request: proto::GetChannelMessages,
3886    response: Response<proto::GetChannelMessages>,
3887    session: Session,
3888) -> Result<()> {
3889    let channel_id = ChannelId::from_proto(request.channel_id);
3890    let messages = session
3891        .db()
3892        .await
3893        .get_channel_messages(
3894            channel_id,
3895            session.user_id(),
3896            MESSAGE_COUNT_PER_PAGE,
3897            Some(MessageId::from_proto(request.before_message_id)),
3898        )
3899        .await?;
3900    response.send(proto::GetChannelMessagesResponse {
3901        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3902        messages,
3903    })?;
3904    Ok(())
3905}
3906
3907/// Retrieve specific chat messages
3908async fn get_channel_messages_by_id(
3909    request: proto::GetChannelMessagesById,
3910    response: Response<proto::GetChannelMessagesById>,
3911    session: Session,
3912) -> Result<()> {
3913    let message_ids = request
3914        .message_ids
3915        .iter()
3916        .map(|id| MessageId::from_proto(*id))
3917        .collect::<Vec<_>>();
3918    let messages = session
3919        .db()
3920        .await
3921        .get_channel_messages_by_id(session.user_id(), &message_ids)
3922        .await?;
3923    response.send(proto::GetChannelMessagesResponse {
3924        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925        messages,
3926    })?;
3927    Ok(())
3928}
3929
3930/// Retrieve the current users notifications
3931async fn get_notifications(
3932    request: proto::GetNotifications,
3933    response: Response<proto::GetNotifications>,
3934    session: Session,
3935) -> Result<()> {
3936    let notifications = session
3937        .db()
3938        .await
3939        .get_notifications(
3940            session.user_id(),
3941            NOTIFICATION_COUNT_PER_PAGE,
3942            request.before_id.map(db::NotificationId::from_proto),
3943        )
3944        .await?;
3945    response.send(proto::GetNotificationsResponse {
3946        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3947        notifications,
3948    })?;
3949    Ok(())
3950}
3951
3952/// Mark notifications as read
3953async fn mark_notification_as_read(
3954    request: proto::MarkNotificationRead,
3955    response: Response<proto::MarkNotificationRead>,
3956    session: Session,
3957) -> Result<()> {
3958    let database = &session.db().await;
3959    let notifications = database
3960        .mark_notification_as_read_by_id(
3961            session.user_id(),
3962            NotificationId::from_proto(request.notification_id),
3963        )
3964        .await?;
3965    send_notifications(
3966        &*session.connection_pool().await,
3967        &session.peer,
3968        notifications,
3969    );
3970    response.send(proto::Ack {})?;
3971    Ok(())
3972}
3973
3974/// Get the current users information
3975async fn get_private_user_info(
3976    _request: proto::GetPrivateUserInfo,
3977    response: Response<proto::GetPrivateUserInfo>,
3978    session: Session,
3979) -> Result<()> {
3980    let db = session.db().await;
3981
3982    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3983    let user = db
3984        .get_user_by_id(session.user_id())
3985        .await?
3986        .context("user not found")?;
3987    let flags = db.get_user_flags(session.user_id()).await?;
3988
3989    response.send(proto::GetPrivateUserInfoResponse {
3990        metrics_id,
3991        staff: user.admin,
3992        flags,
3993        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3994    })?;
3995    Ok(())
3996}
3997
3998/// Accept the terms of service (tos) on behalf of the current user
3999async fn accept_terms_of_service(
4000    _request: proto::AcceptTermsOfService,
4001    response: Response<proto::AcceptTermsOfService>,
4002    session: Session,
4003) -> Result<()> {
4004    let db = session.db().await;
4005
4006    let accepted_tos_at = Utc::now();
4007    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4008        .await?;
4009
4010    response.send(proto::AcceptTermsOfServiceResponse {
4011        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4012    })?;
4013    Ok(())
4014}
4015
/// The minimum account age an account must have in order to use the LLM service.
///
/// Set to 30 days. It is `pub`, so code outside this module may read it; it is not referenced by
/// `get_llm_api_token` itself (presumably enforced elsewhere — TODO confirm).
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4018
/// Issues an LLM API token for the current user.
///
/// Steps:
/// - The user must exist and must have accepted the terms of service.
/// - If the user has no billing customer yet, a Stripe customer is found or created by the
///   user's email address, and a billing customer is found or created for it.
/// - If the user has no active billing subscription, they are subscribed to Zed Free and that
///   subscription is recorded in the database.
/// - The token is built by `LlmTokenClaims::create`.
///
/// The token is built from:
/// - the user and their staff status;
/// - the billing customer and billing preferences;
/// - the feature flags;
/// - the subscription;
/// - the session's system id;
/// - the server config.
///
/// # Errors
/// Fails in any of these cases:
/// - the user is missing;
/// - the user has not accepted the terms of service;
/// - the Stripe client or Stripe billing object is not configured;
/// - no billing customer can be found or created;
/// - any database or Stripe call fails;
/// - token creation fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or create one backed by a Stripe customer.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are put on the Zed Free plan.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4102
4103fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4104    let message = match message {
4105        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4106        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4107        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4108        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4109        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4110            code: frame.code.into(),
4111            reason: frame.reason.as_str().to_owned().into(),
4112        })),
4113        // We should never receive a frame while reading the message, according
4114        // to the `tungstenite` maintainers:
4115        //
4116        // > It cannot occur when you read messages from the WebSocket, but it
4117        // > can be used when you want to send the raw frames (e.g. you want to
4118        // > send the frames to the WebSocket without composing the full message first).
4119        // >
4120        // > — https://github.com/snapview/tungstenite-rs/issues/268
4121        TungsteniteMessage::Frame(_) => {
4122            bail!("received an unexpected frame while reading the message")
4123        }
4124    };
4125
4126    Ok(message)
4127}
4128
4129fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4130    match message {
4131        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4132        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4133        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4134        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4135        AxumMessage::Close(frame) => {
4136            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4137                code: frame.code.into(),
4138                reason: frame.reason.as_ref().into(),
4139            }))
4140        }
4141    }
4142}
4143
4144fn notify_membership_updated(
4145    connection_pool: &mut ConnectionPool,
4146    result: MembershipUpdated,
4147    user_id: UserId,
4148    peer: &Peer,
4149) {
4150    for membership in &result.new_channels.channel_memberships {
4151        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4152    }
4153    for channel_id in &result.removed_channels {
4154        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4155    }
4156
4157    let user_channels_update = proto::UpdateUserChannels {
4158        channel_memberships: result
4159            .new_channels
4160            .channel_memberships
4161            .iter()
4162            .map(|cm| proto::ChannelMembership {
4163                channel_id: cm.channel_id.to_proto(),
4164                role: cm.role.into(),
4165            })
4166            .collect(),
4167        ..Default::default()
4168    };
4169
4170    let mut update = build_channels_update(result.new_channels);
4171    update.delete_channels = result
4172        .removed_channels
4173        .into_iter()
4174        .map(|id| id.to_proto())
4175        .collect();
4176    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4177
4178    for connection_id in connection_pool.user_connection_ids(user_id) {
4179        peer.send(connection_id, user_channels_update.clone())
4180            .trace_err();
4181        peer.send(connection_id, update.clone()).trace_err();
4182    }
4183}
4184
4185fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4186    proto::UpdateUserChannels {
4187        channel_memberships: channels
4188            .channel_memberships
4189            .iter()
4190            .map(|m| proto::ChannelMembership {
4191                channel_id: m.channel_id.to_proto(),
4192                role: m.role.into(),
4193            })
4194            .collect(),
4195        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4196        observed_channel_message_id: channels.observed_channel_messages.clone(),
4197    }
4198}
4199
4200fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4201    let mut update = proto::UpdateChannels::default();
4202
4203    for channel in channels.channels {
4204        update.channels.push(channel.to_proto());
4205    }
4206
4207    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4208    update.latest_channel_message_ids = channels.latest_channel_messages;
4209
4210    for (channel_id, participants) in channels.channel_participants {
4211        update
4212            .channel_participants
4213            .push(proto::ChannelParticipants {
4214                channel_id: channel_id.to_proto(),
4215                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4216            });
4217    }
4218
4219    for channel in channels.invited_channels {
4220        update.channel_invitations.push(channel.to_proto());
4221    }
4222
4223    update
4224}
4225
4226fn build_initial_contacts_update(
4227    contacts: Vec<db::Contact>,
4228    pool: &ConnectionPool,
4229) -> proto::UpdateContacts {
4230    let mut update = proto::UpdateContacts::default();
4231
4232    for contact in contacts {
4233        match contact {
4234            db::Contact::Accepted { user_id, busy } => {
4235                update.contacts.push(contact_for_user(user_id, busy, pool));
4236            }
4237            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4238            db::Contact::Incoming { user_id } => {
4239                update
4240                    .incoming_requests
4241                    .push(proto::IncomingContactRequest {
4242                        requester_id: user_id.to_proto(),
4243                    })
4244            }
4245        }
4246    }
4247
4248    update
4249}
4250
4251fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4252    proto::Contact {
4253        user_id: user_id.to_proto(),
4254        online: pool.is_user_online(user_id),
4255        busy,
4256    }
4257}
4258
4259fn room_updated(room: &proto::Room, peer: &Peer) {
4260    broadcast(
4261        None,
4262        room.participants
4263            .iter()
4264            .filter_map(|participant| Some(participant.peer_id?.into())),
4265        |peer_id| {
4266            peer.send(
4267                peer_id,
4268                proto::RoomUpdated {
4269                    room: Some(room.clone()),
4270                },
4271            )
4272        },
4273    );
4274}
4275
4276fn channel_updated(
4277    channel: &db::channel::Model,
4278    room: &proto::Room,
4279    peer: &Peer,
4280    pool: &ConnectionPool,
4281) {
4282    let participants = room
4283        .participants
4284        .iter()
4285        .map(|p| p.user_id)
4286        .collect::<Vec<_>>();
4287
4288    broadcast(
4289        None,
4290        pool.channel_connection_ids(channel.root_id())
4291            .filter_map(|(channel_id, role)| {
4292                role.can_see_channel(channel.visibility)
4293                    .then_some(channel_id)
4294            }),
4295        |peer_id| {
4296            peer.send(
4297                peer_id,
4298                proto::UpdateChannels {
4299                    channel_participants: vec![proto::ChannelParticipants {
4300                        channel_id: channel.id.to_proto(),
4301                        participant_user_ids: participants.clone(),
4302                    }],
4303                    ..Default::default()
4304                },
4305            )
4306        },
4307    );
4308}
4309
4310async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4311    let db = session.db().await;
4312
4313    let contacts = db.get_contacts(user_id).await?;
4314    let busy = db.is_user_busy(user_id).await?;
4315
4316    let pool = session.connection_pool().await;
4317    let updated_contact = contact_for_user(user_id, busy, &pool);
4318    for contact in contacts {
4319        if let db::Contact::Accepted {
4320            user_id: contact_user_id,
4321            ..
4322        } = contact
4323        {
4324            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4325                session
4326                    .peer
4327                    .send(
4328                        contact_conn_id,
4329                        proto::UpdateContacts {
4330                            contacts: vec![updated_contact.clone()],
4331                            remove_contacts: Default::default(),
4332                            incoming_requests: Default::default(),
4333                            remove_incoming_requests: Default::default(),
4334                            outgoing_requests: Default::default(),
4335                            remove_outgoing_requests: Default::default(),
4336                        },
4337                    )
4338                    .trace_err();
4339            }
4340        }
4341    }
4342    Ok(())
4343}
4344
4345async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4346    let mut contacts_to_update = HashSet::default();
4347
4348    let room_id;
4349    let canceled_calls_to_user_ids;
4350    let livekit_room;
4351    let delete_livekit_room;
4352    let room;
4353    let channel;
4354
4355    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4356        contacts_to_update.insert(session.user_id());
4357
4358        for project in left_room.left_projects.values() {
4359            project_left(project, session);
4360        }
4361
4362        room_id = RoomId::from_proto(left_room.room.id);
4363        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4364        livekit_room = mem::take(&mut left_room.room.livekit_room);
4365        delete_livekit_room = left_room.deleted;
4366        room = mem::take(&mut left_room.room);
4367        channel = mem::take(&mut left_room.channel);
4368
4369        room_updated(&room, &session.peer);
4370    } else {
4371        return Ok(());
4372    }
4373
4374    if let Some(channel) = channel {
4375        channel_updated(
4376            &channel,
4377            &room,
4378            &session.peer,
4379            &*session.connection_pool().await,
4380        );
4381    }
4382
4383    {
4384        let pool = session.connection_pool().await;
4385        for canceled_user_id in canceled_calls_to_user_ids {
4386            for connection_id in pool.user_connection_ids(canceled_user_id) {
4387                session
4388                    .peer
4389                    .send(
4390                        connection_id,
4391                        proto::CallCanceled {
4392                            room_id: room_id.to_proto(),
4393                        },
4394                    )
4395                    .trace_err();
4396            }
4397            contacts_to_update.insert(canceled_user_id);
4398        }
4399    }
4400
4401    for contact_user_id in contacts_to_update {
4402        update_user_contacts(contact_user_id, session).await?;
4403    }
4404
4405    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4406        live_kit
4407            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4408            .await
4409            .trace_err();
4410
4411        if delete_livekit_room {
4412            live_kit.delete_room(livekit_room).await.trace_err();
4413        }
4414    }
4415
4416    Ok(())
4417}
4418
4419async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4420    let left_channel_buffers = session
4421        .db()
4422        .await
4423        .leave_channel_buffers(session.connection_id)
4424        .await?;
4425
4426    for left_buffer in left_channel_buffers {
4427        channel_buffer_updated(
4428            session.connection_id,
4429            left_buffer.connections,
4430            &proto::UpdateChannelBufferCollaborators {
4431                channel_id: left_buffer.channel_id.to_proto(),
4432                collaborators: left_buffer.collaborators,
4433            },
4434            &session.peer,
4435        );
4436    }
4437
4438    Ok(())
4439}
4440
4441fn project_left(project: &db::LeftProject, session: &Session) {
4442    for connection_id in &project.connection_ids {
4443        if project.should_unshare {
4444            session
4445                .peer
4446                .send(
4447                    *connection_id,
4448                    proto::UnshareProject {
4449                        project_id: project.id.to_proto(),
4450                    },
4451                )
4452                .trace_err();
4453        } else {
4454            session
4455                .peer
4456                .send(
4457                    *connection_id,
4458                    proto::RemoveProjectCollaborator {
4459                        project_id: project.id.to_proto(),
4460                        peer_id: Some(session.connection_id.into()),
4461                    },
4462                )
4463                .trace_err();
4464        }
4465    }
4466}
4467
4468pub trait ResultExt {
4469    type Ok;
4470
4471    fn trace_err(self) -> Option<Self::Ok>;
4472}
4473
4474impl<T, E> ResultExt for Result<T, E>
4475where
4476    E: std::fmt::Debug,
4477{
4478    type Ok = T;
4479
4480    #[track_caller]
4481    fn trace_err(self) -> Option<T> {
4482        match self {
4483            Ok(value) => Some(value),
4484            Err(error) => {
4485                tracing::error!("{:?}", error);
4486                None
4487            }
4488        }
4489    }
4490}