rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::{
   9    AppState, Error, Result, auth,
  10    db::{
  11        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  12        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  13        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  14        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  15    },
  16    executor::Executor,
  17};
  18use anyhow::{Context as _, anyhow, bail};
  19use async_tungstenite::tungstenite::{
  20    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  21};
  22use axum::{
  23    Extension, Router, TypedHeader,
  24    body::Body,
  25    extract::{
  26        ConnectInfo, WebSocketUpgrade,
  27        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  28    },
  29    headers::{Header, HeaderName},
  30    http::StatusCode,
  31    middleware,
  32    response::IntoResponse,
  33    routing::get,
  34};
  35use chrono::Utc;
  36use collections::{HashMap, HashSet};
  37pub use connection_pool::{ConnectionPool, ZedVersion};
  38use core::fmt::{self, Debug, Formatter};
  39use reqwest_client::ReqwestClient;
  40use rpc::proto::split_repository_update;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
  /// Grace period associated with reconnecting clients.
///
/// NOTE(review): consumed outside this view — presumably how long a disconnected client's
/// state is kept while it tries to reconnect; confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting longer than that means
/// the old pods are gone by the time their rooms and channel buffers are reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (NOTE(review): used outside this view — confirm).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length (NOTE(review): used outside this view; presumably a limit on chat
/// message bodies — confirm the unit).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (NOTE(review): used outside this view — confirm).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn user(&self) -> &User {
 114        match self {
 115            Principal::User(user) => user,
 116            Principal::Impersonated { user, .. } => user,
 117        }
 118    }
 119
 120    fn update_span(&self, span: &tracing::Span) {
 121        match &self {
 122            Principal::User(user) => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125            }
 126            Principal::Impersonated { user, admin } => {
 127                span.record("user_id", user.id.0);
 128                span.record("login", &user.github_login);
 129                span.record("impersonator", &admin.github_login);
 130            }
 131        }
 132    }
 133}
 134
 135#[derive(Clone)]
 136struct Session {
 137    principal: Principal,
 138    connection_id: ConnectionId,
 139    db: Arc<tokio::sync::Mutex<DbHandle>>,
 140    peer: Arc<Peer>,
 141    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 142    app_state: Arc<AppState>,
 143    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 144    /// The GeoIP country code for the user.
 145    #[allow(unused)]
 146    geoip_country_code: Option<String>,
 147    system_id: Option<String>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn is_staff(&self) -> bool {
 172        match &self.principal {
 173            Principal::User(user) => user.admin,
 174            Principal::Impersonated { .. } => true,
 175        }
 176    }
 177
 178    fn user_id(&self) -> UserId {
 179        match &self.principal {
 180            Principal::User(user) => user.id,
 181            Principal::Impersonated { user, .. } => user.id,
 182        }
 183    }
 184
 185    pub fn email(&self) -> Option<String> {
 186        match &self.principal {
 187            Principal::User(user) => user.email_address.clone(),
 188            Principal::Impersonated { user, .. } => user.email_address.clone(),
 189        }
 190    }
 191}
 192
 193impl Debug for Session {
 194    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 195        let mut result = f.debug_struct("Session");
 196        match &self.principal {
 197            Principal::User(user) => {
 198                result.field("user", &user.github_login);
 199            }
 200            Principal::Impersonated { user, admin } => {
 201                result.field("user", &user.github_login);
 202                result.field("impersonator", &admin.github_login);
 203            }
 204        }
 205        result.field("connection_id", &self.connection_id).finish()
 206    }
 207}
 208
 209struct DbHandle(Arc<Database>);
 210
 211impl Deref for DbHandle {
 212    type Target = Database;
 213
 214    fn deref(&self) -> &Self::Target {
 215        self.0.as_ref()
 216    }
 217}
 218
 219pub struct Server {
 220    id: parking_lot::Mutex<ServerId>,
 221    peer: Arc<Peer>,
 222    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 223    app_state: Arc<AppState>,
 224    handlers: HashMap<TypeId, MessageHandler>,
 225    teardown: watch::Sender<bool>,
 226}
 227
 228pub(crate) struct ConnectionPoolGuard<'a> {
 229    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 230    _not_send: PhantomData<Rc<()>>,
 231}
 232
 233#[derive(Serialize)]
 234pub struct ServerSnapshot<'a> {
 235    peer: &'a Peer,
 236    #[serde(serialize_with = "serialize_deref")]
 237    connection_pool: ConnectionPoolGuard<'a>,
 238}
 239
 240pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 241where
 242    S: Serializer,
 243    T: Deref<Target = U>,
 244    U: Serialize,
 245{
 246    Serialize::serialize(value.deref(), serializer)
 247}
 248
 249impl Server {
 250    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 251        let mut server = Self {
 252            id: parking_lot::Mutex::new(id),
 253            peer: Peer::new(id.0 as u32),
 254            app_state: app_state.clone(),
 255            connection_pool: Default::default(),
 256            handlers: Default::default(),
 257            teardown: watch::channel(false).0,
 258        };
 259
 260        server
 261            .add_request_handler(ping)
 262            .add_request_handler(create_room)
 263            .add_request_handler(join_room)
 264            .add_request_handler(rejoin_room)
 265            .add_request_handler(leave_room)
 266            .add_request_handler(set_room_participant_role)
 267            .add_request_handler(call)
 268            .add_request_handler(cancel_call)
 269            .add_message_handler(decline_call)
 270            .add_request_handler(update_participant_location)
 271            .add_request_handler(share_project)
 272            .add_message_handler(unshare_project)
 273            .add_request_handler(join_project)
 274            .add_message_handler(leave_project)
 275            .add_request_handler(update_project)
 276            .add_request_handler(update_worktree)
 277            .add_request_handler(update_repository)
 278            .add_request_handler(remove_repository)
 279            .add_message_handler(start_language_server)
 280            .add_message_handler(update_language_server)
 281            .add_message_handler(update_diagnostic_summary)
 282            .add_message_handler(update_worktree_settings)
 283            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 284            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 285            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 286            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 287            .add_request_handler(forward_find_search_candidates_request)
 288            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 291            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 293            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 294            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 295            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 296            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 297            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 298            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 299            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 300            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 301            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 303            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 304            .add_request_handler(
 305                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 306            )
 307            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 308            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 309            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 310            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 311            .add_request_handler(
 312                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 313            )
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 332            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 333            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 342            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 345            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 348            .add_message_handler(create_buffer_for_peer)
 349            .add_request_handler(update_buffer)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 356            .add_request_handler(get_users)
 357            .add_request_handler(fuzzy_search_users)
 358            .add_request_handler(request_contact)
 359            .add_request_handler(remove_contact)
 360            .add_request_handler(respond_to_contact_request)
 361            .add_message_handler(subscribe_to_channels)
 362            .add_request_handler(create_channel)
 363            .add_request_handler(delete_channel)
 364            .add_request_handler(invite_channel_member)
 365            .add_request_handler(remove_channel_member)
 366            .add_request_handler(set_channel_member_role)
 367            .add_request_handler(set_channel_visibility)
 368            .add_request_handler(rename_channel)
 369            .add_request_handler(join_channel_buffer)
 370            .add_request_handler(leave_channel_buffer)
 371            .add_message_handler(update_channel_buffer)
 372            .add_request_handler(rejoin_channel_buffers)
 373            .add_request_handler(get_channel_members)
 374            .add_request_handler(respond_to_channel_invite)
 375            .add_request_handler(join_channel)
 376            .add_request_handler(join_channel_chat)
 377            .add_message_handler(leave_channel_chat)
 378            .add_request_handler(send_channel_message)
 379            .add_request_handler(remove_channel_message)
 380            .add_request_handler(update_channel_message)
 381            .add_request_handler(get_channel_messages)
 382            .add_request_handler(get_channel_messages_by_id)
 383            .add_request_handler(get_notifications)
 384            .add_request_handler(mark_notification_as_read)
 385            .add_request_handler(move_channel)
 386            .add_request_handler(follow)
 387            .add_message_handler(unfollow)
 388            .add_message_handler(update_followers)
 389            .add_request_handler(get_private_user_info)
 390            .add_request_handler(get_llm_api_token)
 391            .add_request_handler(accept_terms_of_service)
 392            .add_message_handler(acknowledge_channel_message)
 393            .add_message_handler(acknowledge_buffer_version)
 394            .add_request_handler(get_supermaven_api_key)
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 396            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 401            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 404            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 406            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 407            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 408            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 410            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 411            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 412            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 414            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 415            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 416            .add_message_handler(update_context);
 417
 418        Arc::new(server)
 419    }
 420
 421    pub async fn start(&self) -> Result<()> {
 422        let server_id = *self.id.lock();
 423        let app_state = self.app_state.clone();
 424        let peer = self.peer.clone();
 425        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 426        let pool = self.connection_pool.clone();
 427        let livekit_client = self.app_state.livekit_client.clone();
 428
 429        let span = info_span!("start server");
 430        self.app_state.executor.spawn_detached(
 431            async move {
 432                tracing::info!("waiting for cleanup timeout");
 433                timeout.await;
 434                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 435                if let Some((room_ids, channel_ids)) = app_state
 436                    .db
 437                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 438                    .await
 439                    .trace_err()
 440                {
 441                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 442                    tracing::info!(
 443                        stale_channel_buffer_count = channel_ids.len(),
 444                        "retrieved stale channel buffers"
 445                    );
 446
 447                    for channel_id in channel_ids {
 448                        if let Some(refreshed_channel_buffer) = app_state
 449                            .db
 450                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 451                            .await
 452                            .trace_err()
 453                        {
 454                            for connection_id in refreshed_channel_buffer.connection_ids {
 455                                peer.send(
 456                                    connection_id,
 457                                    proto::UpdateChannelBufferCollaborators {
 458                                        channel_id: channel_id.to_proto(),
 459                                        collaborators: refreshed_channel_buffer
 460                                            .collaborators
 461                                            .clone(),
 462                                    },
 463                                )
 464                                .trace_err();
 465                            }
 466                        }
 467                    }
 468
 469                    for room_id in room_ids {
 470                        let mut contacts_to_update = HashSet::default();
 471                        let mut canceled_calls_to_user_ids = Vec::new();
 472                        let mut livekit_room = String::new();
 473                        let mut delete_livekit_room = false;
 474
 475                        if let Some(mut refreshed_room) = app_state
 476                            .db
 477                            .clear_stale_room_participants(room_id, server_id)
 478                            .await
 479                            .trace_err()
 480                        {
 481                            tracing::info!(
 482                                room_id = room_id.0,
 483                                new_participant_count = refreshed_room.room.participants.len(),
 484                                "refreshed room"
 485                            );
 486                            room_updated(&refreshed_room.room, &peer);
 487                            if let Some(channel) = refreshed_room.channel.as_ref() {
 488                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 489                            }
 490                            contacts_to_update
 491                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 492                            contacts_to_update
 493                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 494                            canceled_calls_to_user_ids =
 495                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 496                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 497                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 498                        }
 499
 500                        {
 501                            let pool = pool.lock();
 502                            for canceled_user_id in canceled_calls_to_user_ids {
 503                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 504                                    peer.send(
 505                                        connection_id,
 506                                        proto::CallCanceled {
 507                                            room_id: room_id.to_proto(),
 508                                        },
 509                                    )
 510                                    .trace_err();
 511                                }
 512                            }
 513                        }
 514
 515                        for user_id in contacts_to_update {
 516                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 517                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 518                            if let Some((busy, contacts)) = busy.zip(contacts) {
 519                                let pool = pool.lock();
 520                                let updated_contact = contact_for_user(user_id, busy, &pool);
 521                                for contact in contacts {
 522                                    if let db::Contact::Accepted {
 523                                        user_id: contact_user_id,
 524                                        ..
 525                                    } = contact
 526                                    {
 527                                        for contact_conn_id in
 528                                            pool.user_connection_ids(contact_user_id)
 529                                        {
 530                                            peer.send(
 531                                                contact_conn_id,
 532                                                proto::UpdateContacts {
 533                                                    contacts: vec![updated_contact.clone()],
 534                                                    remove_contacts: Default::default(),
 535                                                    incoming_requests: Default::default(),
 536                                                    remove_incoming_requests: Default::default(),
 537                                                    outgoing_requests: Default::default(),
 538                                                    remove_outgoing_requests: Default::default(),
 539                                                },
 540                                            )
 541                                            .trace_err();
 542                                        }
 543                                    }
 544                                }
 545                            }
 546                        }
 547
 548                        if let Some(live_kit) = livekit_client.as_ref() {
 549                            if delete_livekit_room {
 550                                live_kit.delete_room(livekit_room).await.trace_err();
 551                            }
 552                        }
 553                    }
 554                }
 555
 556                app_state
 557                    .db
 558                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 559                    .await
 560                    .trace_err();
 561            }
 562            .instrument(span),
 563        );
 564        Ok(())
 565    }
 566
 567    pub fn teardown(&self) {
 568        self.peer.teardown();
 569        self.connection_pool.lock().reset();
 570        let _ = self.teardown.send(true);
 571    }
 572
 /// Test-only: tears the server down, then revives it under a new `id`.
    ///
    /// The peer is reset to the new id and `false` is published on the teardown channel; a
    /// failed send is ignored.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        let mut current_id = self.id.lock();
        *current_id = id;
        drop(current_id);
        self.peer.reset(id.0 as u32);
        _ = self.teardown.send(false);
    }
 580
 /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        let current = self.id.lock();
        *current
    }
 585
 586    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 587    where
 588        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 589        Fut: 'static + Send + Future<Output = Result<()>>,
 590        M: EnvelopedMessage,
 591    {
 592        let prev_handler = self.handlers.insert(
 593            TypeId::of::<M>(),
 594            Box::new(move |envelope, session| {
 595                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 596                let received_at = envelope.received_at;
 597                tracing::info!("message received");
 598                let start_time = Instant::now();
 599                let future = (handler)(*envelope, session);
 600                async move {
 601                    let result = future.await;
 602                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 603                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 604                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 605                    let payload_type = M::NAME;
 606
 607                    match result {
 608                        Err(error) => {
 609                            tracing::error!(
 610                                ?error,
 611                                total_duration_ms,
 612                                processing_duration_ms,
 613                                queue_duration_ms,
 614                                payload_type,
 615                                "error handling message"
 616                            )
 617                        }
 618                        Ok(()) => tracing::info!(
 619                            total_duration_ms,
 620                            processing_duration_ms,
 621                            queue_duration_ms,
 622                            "finished handling message"
 623                        ),
 624                    }
 625                }
 626                .boxed()
 627            }),
 628        );
 629        if prev_handler.is_some() {
 630            panic!("registered a handler for the same message twice");
 631        }
 632        self
 633    }
 634
 635    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 636    where
 637        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 638        Fut: 'static + Send + Future<Output = Result<()>>,
 639        M: EnvelopedMessage,
 640    {
 641        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 642        self
 643    }
 644
 645    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 646    where
 647        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 648        Fut: Send + Future<Output = Result<()>>,
 649        M: RequestMessage,
 650    {
 651        let handler = Arc::new(handler);
 652        self.add_handler(move |envelope, session| {
 653            let receipt = envelope.receipt();
 654            let handler = handler.clone();
 655            async move {
 656                let peer = session.peer.clone();
 657                let responded = Arc::new(AtomicBool::default());
 658                let response = Response {
 659                    peer: peer.clone(),
 660                    responded: responded.clone(),
 661                    receipt,
 662                };
 663                match (handler)(envelope.payload, response, session).await {
 664                    Ok(()) => {
 665                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 666                            Ok(())
 667                        } else {
 668                            Err(anyhow!("handler did not send a response"))?
 669                        }
 670                    }
 671                    Err(error) => {
 672                        let proto_err = match &error {
 673                            Error::Internal(err) => err.to_proto(),
 674                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 675                        };
 676                        peer.respond_with_error(receipt, proto_err)?;
 677                        Err(error)
 678                    }
 679                }
 680            }
 681        })
 682    }
 683
 684    pub fn handle_connection(
 685        self: &Arc<Self>,
 686        connection: Connection,
 687        address: String,
 688        principal: Principal,
 689        zed_version: ZedVersion,
 690        geoip_country_code: Option<String>,
 691        system_id: Option<String>,
 692        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 693        executor: Executor,
 694    ) -> impl Future<Output = ()> + use<> {
 695        let this = self.clone();
 696        let span = info_span!("handle connection", %address,
 697            connection_id=field::Empty,
 698            user_id=field::Empty,
 699            login=field::Empty,
 700            impersonator=field::Empty,
 701            geoip_country_code=field::Empty
 702        );
 703        principal.update_span(&span);
 704        if let Some(country_code) = geoip_country_code.as_ref() {
 705            span.record("geoip_country_code", country_code);
 706        }
 707
 708        let mut teardown = self.teardown.subscribe();
 709        async move {
 710            if *teardown.borrow() {
 711                tracing::error!("server is tearing down");
 712                return
 713            }
 714            let (connection_id, handle_io, mut incoming_rx) = this
 715                .peer
 716                .add_connection(connection, {
 717                    let executor = executor.clone();
 718                    move |duration| executor.sleep(duration)
 719                });
 720            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 721
 722            tracing::info!("connection opened");
 723
 724            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 725            let http_client = match ReqwestClient::user_agent(&user_agent) {
 726                Ok(http_client) => Arc::new(http_client),
 727                Err(error) => {
 728                    tracing::error!(?error, "failed to create HTTP client");
 729                    return;
 730                }
 731            };
 732
 733            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 734                    supermaven_admin_api_key.to_string(),
 735                    http_client.clone(),
 736                )));
 737
 738            let session = Session {
 739                principal: principal.clone(),
 740                connection_id,
 741                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 742                peer: this.peer.clone(),
 743                connection_pool: this.connection_pool.clone(),
 744                app_state: this.app_state.clone(),
 745                geoip_country_code,
 746                system_id,
 747                _executor: executor.clone(),
 748                supermaven_client,
 749            };
 750
 751            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
 752                tracing::error!(?error, "failed to send initial client update");
 753                return;
 754            }
 755
 756            let handle_io = handle_io.fuse();
 757            futures::pin_mut!(handle_io);
 758
 759            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 760            // This prevents deadlocks when e.g., client A performs a request to client B and
 761            // client B performs a request to client A. If both clients stop processing further
 762            // messages until their respective request completes, they won't have a chance to
 763            // respond to the other client's request and cause a deadlock.
 764            //
 765            // This arrangement ensures we will attempt to process earlier messages first, but fall
 766            // back to processing messages arrived later in the spirit of making progress.
 767            let mut foreground_message_handlers = FuturesUnordered::new();
 768            let concurrent_handlers = Arc::new(Semaphore::new(256));
 769            loop {
 770                let next_message = async {
 771                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 772                    let message = incoming_rx.next().await;
 773                    (permit, message)
 774                }.fuse();
 775                futures::pin_mut!(next_message);
 776                futures::select_biased! {
 777                    _ = teardown.changed().fuse() => return,
 778                    result = handle_io => {
 779                        if let Err(error) = result {
 780                            tracing::error!(?error, "error handling I/O");
 781                        }
 782                        break;
 783                    }
 784                    _ = foreground_message_handlers.next() => {}
 785                    next_message = next_message => {
 786                        let (permit, message) = next_message;
 787                        if let Some(message) = message {
 788                            let type_name = message.payload_type_name();
 789                            // note: we copy all the fields from the parent span so we can query them in the logs.
 790                            // (https://github.com/tokio-rs/tracing/issues/2670).
 791                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 792                                user_id=field::Empty,
 793                                login=field::Empty,
 794                                impersonator=field::Empty,
 795                            );
 796                            principal.update_span(&span);
 797                            let span_enter = span.enter();
 798                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 799                                let is_background = message.is_background();
 800                                let handle_message = (handler)(message, session.clone());
 801                                drop(span_enter);
 802
 803                                let handle_message = async move {
 804                                    handle_message.await;
 805                                    drop(permit);
 806                                }.instrument(span);
 807                                if is_background {
 808                                    executor.spawn_detached(handle_message);
 809                                } else {
 810                                    foreground_message_handlers.push(handle_message);
 811                                }
 812                            } else {
 813                                tracing::error!("no message handler");
 814                            }
 815                        } else {
 816                            tracing::info!("connection closed");
 817                            break;
 818                        }
 819                    }
 820                }
 821            }
 822
 823            drop(foreground_message_handlers);
 824            tracing::info!("signing out");
 825            if let Err(error) = connection_lost(session, teardown, executor).await {
 826                tracing::error!(?error, "error signing out");
 827            }
 828
 829        }.instrument(span)
 830    }
 831
 832    async fn send_initial_client_update(
 833        &self,
 834        connection_id: ConnectionId,
 835        zed_version: ZedVersion,
 836        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 837        session: &Session,
 838    ) -> Result<()> {
 839        self.peer.send(
 840            connection_id,
 841            proto::Hello {
 842                peer_id: Some(connection_id.into()),
 843            },
 844        )?;
 845        tracing::info!("sent hello message");
 846        if let Some(send_connection_id) = send_connection_id.take() {
 847            let _ = send_connection_id.send(connection_id);
 848        }
 849
 850        match &session.principal {
 851            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 852                if !user.connected_once {
 853                    self.peer.send(connection_id, proto::ShowContacts {})?;
 854                    self.app_state
 855                        .db
 856                        .set_user_connected_once(user.id, true)
 857                        .await?;
 858                }
 859
 860                update_user_plan(session).await?;
 861
 862                let contacts = self.app_state.db.get_contacts(user.id).await?;
 863
 864                {
 865                    let mut pool = self.connection_pool.lock();
 866                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 867                    self.peer.send(
 868                        connection_id,
 869                        build_initial_contacts_update(contacts, &pool),
 870                    )?;
 871                }
 872
 873                if should_auto_subscribe_to_channels(zed_version) {
 874                    subscribe_user_to_channels(user.id, session).await?;
 875                }
 876
 877                if let Some(incoming_call) =
 878                    self.app_state.db.incoming_call_for_user(user.id).await?
 879                {
 880                    self.peer.send(connection_id, incoming_call)?;
 881                }
 882
 883                update_user_contacts(user.id, session).await?;
 884            }
 885        }
 886
 887        Ok(())
 888    }
 889
 890    pub async fn invite_code_redeemed(
 891        self: &Arc<Self>,
 892        inviter_id: UserId,
 893        invitee_id: UserId,
 894    ) -> Result<()> {
 895        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 896            if let Some(code) = &user.invite_code {
 897                let pool = self.connection_pool.lock();
 898                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 899                for connection_id in pool.user_connection_ids(inviter_id) {
 900                    self.peer.send(
 901                        connection_id,
 902                        proto::UpdateContacts {
 903                            contacts: vec![invitee_contact.clone()],
 904                            ..Default::default()
 905                        },
 906                    )?;
 907                    self.peer.send(
 908                        connection_id,
 909                        proto::UpdateInviteInfo {
 910                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 911                            count: user.invite_count as u32,
 912                        },
 913                    )?;
 914                }
 915            }
 916        }
 917        Ok(())
 918    }
 919
 920    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 921        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 922            if let Some(invite_code) = &user.invite_code {
 923                let pool = self.connection_pool.lock();
 924                for connection_id in pool.user_connection_ids(user_id) {
 925                    self.peer.send(
 926                        connection_id,
 927                        proto::UpdateInviteInfo {
 928                            url: format!(
 929                                "{}{}",
 930                                self.app_state.config.invite_link_prefix, invite_code
 931                            ),
 932                            count: user.invite_count as u32,
 933                        },
 934                    )?;
 935                }
 936            }
 937        }
 938        Ok(())
 939    }
 940
 941    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 942        let user = self
 943            .app_state
 944            .db
 945            .get_user_by_id(user_id)
 946            .await?
 947            .context("user not found")?;
 948
 949        let update_user_plan = make_update_user_plan_message(
 950            &user,
 951            user.admin,
 952            &self.app_state.db,
 953            self.app_state.llm_db.clone(),
 954        )
 955        .await?;
 956
 957        let pool = self.connection_pool.lock();
 958        for connection_id in pool.user_connection_ids(user_id) {
 959            self.peer
 960                .send(connection_id, update_user_plan.clone())
 961                .trace_err();
 962        }
 963
 964        Ok(())
 965    }
 966
 967    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 968        let pool = self.connection_pool.lock();
 969        for connection_id in pool.user_connection_ids(user_id) {
 970            self.peer
 971                .send(connection_id, proto::RefreshLlmToken {})
 972                .trace_err();
 973        }
 974    }
 975
 976    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 977        ServerSnapshot {
 978            connection_pool: ConnectionPoolGuard {
 979                guard: self.connection_pool.lock(),
 980                _not_send: PhantomData,
 981            },
 982            peer: &self.peer,
 983        }
 984    }
 985}
 986
 987impl Deref for ConnectionPoolGuard<'_> {
 988    type Target = ConnectionPool;
 989
 990    fn deref(&self) -> &Self::Target {
 991        &self.guard
 992    }
 993}
 994
 995impl DerefMut for ConnectionPoolGuard<'_> {
 996    fn deref_mut(&mut self) -> &mut Self::Target {
 997        &mut self.guard
 998    }
 999}
1000
1001impl Drop for ConnectionPoolGuard<'_> {
1002    fn drop(&mut self) {
1003        #[cfg(test)]
1004        self.check_invariants();
1005    }
1006}
1007
1008fn broadcast<F>(
1009    sender_id: Option<ConnectionId>,
1010    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1011    mut f: F,
1012) where
1013    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1014{
1015    for receiver_id in receiver_ids {
1016        if Some(receiver_id) != sender_id {
1017            if let Err(error) = f(receiver_id) {
1018                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1019            }
1020        }
1021    }
1022}
1023
1024pub struct ProtocolVersion(u32);
1025
1026impl Header for ProtocolVersion {
1027    fn name() -> &'static HeaderName {
1028        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1029        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1030    }
1031
1032    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1033    where
1034        Self: Sized,
1035        I: Iterator<Item = &'i axum::http::HeaderValue>,
1036    {
1037        let version = values
1038            .next()
1039            .ok_or_else(axum::headers::Error::invalid)?
1040            .to_str()
1041            .map_err(|_| axum::headers::Error::invalid())?
1042            .parse()
1043            .map_err(|_| axum::headers::Error::invalid())?;
1044        Ok(Self(version))
1045    }
1046
1047    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1048        values.extend([self.0.to_string().parse().unwrap()]);
1049    }
1050}
1051
1052pub struct AppVersionHeader(SemanticVersion);
1053impl Header for AppVersionHeader {
1054    fn name() -> &'static HeaderName {
1055        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1057    }
1058
1059    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060    where
1061        Self: Sized,
1062        I: Iterator<Item = &'i axum::http::HeaderValue>,
1063    {
1064        let version = values
1065            .next()
1066            .ok_or_else(axum::headers::Error::invalid)?
1067            .to_str()
1068            .map_err(|_| axum::headers::Error::invalid())?
1069            .parse()
1070            .map_err(|_| axum::headers::Error::invalid())?;
1071        Ok(Self(version))
1072    }
1073
1074    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075        values.extend([self.0.to_string().parse().unwrap()]);
1076    }
1077}
1078
1079pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1080    Router::new()
1081        .route("/rpc", get(handle_websocket_request))
1082        .layer(
1083            ServiceBuilder::new()
1084                .layer(Extension(server.app_state.clone()))
1085                .layer(middleware::from_fn(auth::validate_header)),
1086        )
1087        .route("/metrics", get(handle_metrics))
1088        .layer(Extension(server))
1089}
1090
1091pub async fn handle_websocket_request(
1092    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1093    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1094    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1095    Extension(server): Extension<Arc<Server>>,
1096    Extension(principal): Extension<Principal>,
1097    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1098    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1099    ws: WebSocketUpgrade,
1100) -> axum::response::Response {
1101    if protocol_version != rpc::PROTOCOL_VERSION {
1102        return (
1103            StatusCode::UPGRADE_REQUIRED,
1104            "client must be upgraded".to_string(),
1105        )
1106            .into_response();
1107    }
1108
1109    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1110        return (
1111            StatusCode::UPGRADE_REQUIRED,
1112            "no version header found".to_string(),
1113        )
1114            .into_response();
1115    };
1116
1117    if !version.can_collaborate() {
1118        return (
1119            StatusCode::UPGRADE_REQUIRED,
1120            "client must be upgraded".to_string(),
1121        )
1122            .into_response();
1123    }
1124
1125    let socket_address = socket_address.to_string();
1126    ws.on_upgrade(move |socket| {
1127        let socket = socket
1128            .map_ok(to_tungstenite_message)
1129            .err_into()
1130            .with(|message| async move { to_axum_message(message) });
1131        let connection = Connection::new(Box::pin(socket));
1132        async move {
1133            server
1134                .handle_connection(
1135                    connection,
1136                    socket_address,
1137                    principal,
1138                    version,
1139                    country_code_header.map(|header| header.to_string()),
1140                    system_id_header.map(|header| header.to_string()),
1141                    None,
1142                    Executor::Production,
1143                )
1144                .await;
1145        }
1146    })
1147}
1148
1149pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1150    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151    let connections_metric = CONNECTIONS_METRIC
1152        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1153
1154    let connections = server
1155        .connection_pool
1156        .lock()
1157        .connections()
1158        .filter(|connection| !connection.admin)
1159        .count();
1160    connections_metric.set(connections as _);
1161
1162    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1164        register_int_gauge!(
1165            "shared_projects",
1166            "number of open projects with one or more guests"
1167        )
1168        .unwrap()
1169    });
1170
1171    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1172    shared_projects_metric.set(shared_projects as _);
1173
1174    let encoder = prometheus::TextEncoder::new();
1175    let metric_families = prometheus::gather();
1176    let encoded_metrics = encoder
1177        .encode_to_string(&metric_families)
1178        .map_err(|err| anyhow!("{err}"))?;
1179    Ok(encoded_metrics)
1180}
1181
/// Cleans up after a client connection has ended.
///
/// Immediate steps:
/// 1. disconnect the connection from the peer;
/// 2. remove it from the connection pool — if that fails, the error is returned
///    and no further cleanup happens;
/// 3. report the lost connection to the database (failures are only logged).
///
/// It then waits `RECONNECT_TIMEOUT`, presumably a grace period for the client to
/// reconnect. If `teardown` changes first, it returns `Ok(())` without further
/// cleanup. Otherwise, once the timeout elapses, it:
/// - leaves the session's room and channel buffers (failures are logged);
/// - declines the user's pending call if the user has no other connection online;
/// - runs `update_user_contacts`, whose error is returned.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged and the cleanup continues.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Race the reconnect grace period against server teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // The server is shutting down: skip the remaining cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1228
1229/// Acknowledges a ping from a client, used to keep the connection alive.
1230async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1231    response.send(proto::Ack {})?;
1232    Ok(())
1233}
1234
1235/// Creates a new room for calling (outside of channels)
1236async fn create_room(
1237    _request: proto::CreateRoom,
1238    response: Response<proto::CreateRoom>,
1239    session: Session,
1240) -> Result<()> {
1241    let livekit_room = nanoid::nanoid!(30);
1242
1243    let live_kit_connection_info = util::maybe!(async {
1244        let live_kit = session.app_state.livekit_client.as_ref();
1245        let live_kit = live_kit?;
1246        let user_id = session.user_id().to_string();
1247
1248        let token = live_kit
1249            .room_token(&livekit_room, &user_id.to_string())
1250            .trace_err()?;
1251
1252        Some(proto::LiveKitConnectionInfo {
1253            server_url: live_kit.url().into(),
1254            token,
1255            can_publish: true,
1256        })
1257    })
1258    .await;
1259
1260    let room = session
1261        .db()
1262        .await
1263        .create_room(session.user_id(), session.connection_id, &livekit_room)
1264        .await?;
1265
1266    response.send(proto::CreateRoomResponse {
1267        room: Some(room.clone()),
1268        live_kit_connection_info,
1269    })?;
1270
1271    update_user_contacts(session.user_id(), &session).await?;
1272    Ok(())
1273}
1274
1275/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1276async fn join_room(
1277    request: proto::JoinRoom,
1278    response: Response<proto::JoinRoom>,
1279    session: Session,
1280) -> Result<()> {
1281    let room_id = RoomId::from_proto(request.id);
1282
1283    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1284
1285    if let Some(channel_id) = channel_id {
1286        return join_channel_internal(channel_id, Box::new(response), session).await;
1287    }
1288
1289    let joined_room = {
1290        let room = session
1291            .db()
1292            .await
1293            .join_room(room_id, session.user_id(), session.connection_id)
1294            .await?;
1295        room_updated(&room.room, &session.peer);
1296        room.into_inner()
1297    };
1298
1299    for connection_id in session
1300        .connection_pool()
1301        .await
1302        .user_connection_ids(session.user_id())
1303    {
1304        session
1305            .peer
1306            .send(
1307                connection_id,
1308                proto::CallCanceled {
1309                    room_id: room_id.to_proto(),
1310                },
1311            )
1312            .trace_err();
1313    }
1314
1315    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1316    {
1317        live_kit
1318            .room_token(
1319                &joined_room.room.livekit_room,
1320                &session.user_id().to_string(),
1321            )
1322            .trace_err()
1323            .map(|token| proto::LiveKitConnectionInfo {
1324                server_url: live_kit.url().into(),
1325                token,
1326                can_publish: true,
1327            })
1328    } else {
1329        None
1330    };
1331
1332    response.send(proto::JoinRoomResponse {
1333        room: Some(joined_room.room),
1334        channel_id: None,
1335        live_kit_connection_info,
1336    })?;
1337
1338    update_user_contacts(session.user_id(), &session).await?;
1339    Ok(())
1340}
1341
/// Rejoins a room after the client's connection was interrupted (e.g. by network errors).
///
/// Steps:
/// 1. Restore the session's room state in the database.
/// 2. Reply with the room and with the re-shared and rejoined projects, then
///    broadcast the updated room.
/// 3. For every re-shared project:
///    - tell each collaborator that this peer moved from `old_connection_id` to
///      `session.connection_id`;
///    - forward the project's worktrees to every collaborator except this connection.
/// 4. Stream the rejoined projects' state to this connection via
///    `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, call `channel_updated`.
/// 6. Update the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the database result is consumed with `into_inner()` before the
    // connection pool is awaited below (presumably to release its guard first).
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Every collaborator learns this peer's new id (send errors are logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1433
1434fn notify_rejoined_projects(
1435    rejoined_projects: &mut Vec<RejoinedProject>,
1436    session: &Session,
1437) -> Result<()> {
1438    for project in rejoined_projects.iter() {
1439        for collaborator in &project.collaborators {
1440            session
1441                .peer
1442                .send(
1443                    collaborator.connection_id,
1444                    proto::UpdateProjectCollaborator {
1445                        project_id: project.id.to_proto(),
1446                        old_peer_id: Some(project.old_connection_id.into()),
1447                        new_peer_id: Some(session.connection_id.into()),
1448                    },
1449                )
1450                .trace_err();
1451        }
1452    }
1453
1454    for project in rejoined_projects {
1455        for worktree in mem::take(&mut project.worktrees) {
1456            // Stream this worktree's entries.
1457            let message = proto::UpdateWorktree {
1458                project_id: project.id.to_proto(),
1459                worktree_id: worktree.id,
1460                abs_path: worktree.abs_path.clone(),
1461                root_name: worktree.root_name,
1462                updated_entries: worktree.updated_entries,
1463                removed_entries: worktree.removed_entries,
1464                scan_id: worktree.scan_id,
1465                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1466                updated_repositories: worktree.updated_repositories,
1467                removed_repositories: worktree.removed_repositories,
1468            };
1469            for update in proto::split_worktree_update(message) {
1470                session.peer.send(session.connection_id, update)?;
1471            }
1472
1473            // Stream this worktree's diagnostics.
1474            for summary in worktree.diagnostic_summaries {
1475                session.peer.send(
1476                    session.connection_id,
1477                    proto::UpdateDiagnosticSummary {
1478                        project_id: project.id.to_proto(),
1479                        worktree_id: worktree.id,
1480                        summary: Some(summary),
1481                    },
1482                )?;
1483            }
1484
1485            for settings_file in worktree.settings_files {
1486                session.peer.send(
1487                    session.connection_id,
1488                    proto::UpdateWorktreeSettings {
1489                        project_id: project.id.to_proto(),
1490                        worktree_id: worktree.id,
1491                        path: settings_file.path,
1492                        content: Some(settings_file.content),
1493                        kind: Some(settings_file.kind.to_proto().into()),
1494                    },
1495                )?;
1496            }
1497        }
1498
1499        for repository in mem::take(&mut project.updated_repositories) {
1500            for update in split_repository_update(repository) {
1501                session.peer.send(session.connection_id, update)?;
1502            }
1503        }
1504
1505        for id in mem::take(&mut project.removed_repositories) {
1506            session.peer.send(
1507                session.connection_id,
1508                proto::RemoveRepository {
1509                    project_id: project.id.to_proto(),
1510                    id,
1511                },
1512            )?;
1513        }
1514    }
1515
1516    Ok(())
1517}
1518
1519/// leave room disconnects from the room.
1520async fn leave_room(
1521    _: proto::LeaveRoom,
1522    response: Response<proto::LeaveRoom>,
1523    session: Session,
1524) -> Result<()> {
1525    leave_room_for_session(&session, session.connection_id).await?;
1526    response.send(proto::Ack {})?;
1527    Ok(())
1528}
1529
1530/// Updates the permissions of someone else in the room.
1531async fn set_room_participant_role(
1532    request: proto::SetRoomParticipantRole,
1533    response: Response<proto::SetRoomParticipantRole>,
1534    session: Session,
1535) -> Result<()> {
1536    let user_id = UserId::from_proto(request.user_id);
1537    let role = ChannelRole::from(request.role());
1538
1539    let (livekit_room, can_publish) = {
1540        let room = session
1541            .db()
1542            .await
1543            .set_room_participant_role(
1544                session.user_id(),
1545                RoomId::from_proto(request.room_id),
1546                user_id,
1547                role,
1548            )
1549            .await?;
1550
1551        let livekit_room = room.livekit_room.clone();
1552        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1553        room_updated(&room, &session.peer);
1554        (livekit_room, can_publish)
1555    };
1556
1557    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1558        live_kit
1559            .update_participant(
1560                livekit_room.clone(),
1561                request.user_id.to_string(),
1562                livekit_api::proto::ParticipantPermission {
1563                    can_subscribe: true,
1564                    can_publish,
1565                    can_publish_data: can_publish,
1566                    hidden: false,
1567                    recorder: false,
1568                },
1569            )
1570            .await
1571            .trace_err();
1572    }
1573
1574    response.send(proto::Ack {})?;
1575    Ok(())
1576}
1577
1578/// Call someone else into the current room
1579async fn call(
1580    request: proto::Call,
1581    response: Response<proto::Call>,
1582    session: Session,
1583) -> Result<()> {
1584    let room_id = RoomId::from_proto(request.room_id);
1585    let calling_user_id = session.user_id();
1586    let calling_connection_id = session.connection_id;
1587    let called_user_id = UserId::from_proto(request.called_user_id);
1588    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1589    if !session
1590        .db()
1591        .await
1592        .has_contact(calling_user_id, called_user_id)
1593        .await?
1594    {
1595        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1596    }
1597
1598    let incoming_call = {
1599        let (room, incoming_call) = &mut *session
1600            .db()
1601            .await
1602            .call(
1603                room_id,
1604                calling_user_id,
1605                calling_connection_id,
1606                called_user_id,
1607                initial_project_id,
1608            )
1609            .await?;
1610        room_updated(room, &session.peer);
1611        mem::take(incoming_call)
1612    };
1613    update_user_contacts(called_user_id, &session).await?;
1614
1615    let mut calls = session
1616        .connection_pool()
1617        .await
1618        .user_connection_ids(called_user_id)
1619        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1620        .collect::<FuturesUnordered<_>>();
1621
1622    while let Some(call_response) = calls.next().await {
1623        match call_response.as_ref() {
1624            Ok(_) => {
1625                response.send(proto::Ack {})?;
1626                return Ok(());
1627            }
1628            Err(_) => {
1629                call_response.trace_err();
1630            }
1631        }
1632    }
1633
1634    {
1635        let room = session
1636            .db()
1637            .await
1638            .call_failed(room_id, called_user_id)
1639            .await?;
1640        room_updated(&room, &session.peer);
1641    }
1642    update_user_contacts(called_user_id, &session).await?;
1643
1644    Err(anyhow!("failed to ring user"))?
1645}
1646
1647/// Cancel an outgoing call.
1648async fn cancel_call(
1649    request: proto::CancelCall,
1650    response: Response<proto::CancelCall>,
1651    session: Session,
1652) -> Result<()> {
1653    let called_user_id = UserId::from_proto(request.called_user_id);
1654    let room_id = RoomId::from_proto(request.room_id);
1655    {
1656        let room = session
1657            .db()
1658            .await
1659            .cancel_call(room_id, session.connection_id, called_user_id)
1660            .await?;
1661        room_updated(&room, &session.peer);
1662    }
1663
1664    for connection_id in session
1665        .connection_pool()
1666        .await
1667        .user_connection_ids(called_user_id)
1668    {
1669        session
1670            .peer
1671            .send(
1672                connection_id,
1673                proto::CallCanceled {
1674                    room_id: room_id.to_proto(),
1675                },
1676            )
1677            .trace_err();
1678    }
1679    response.send(proto::Ack {})?;
1680
1681    update_user_contacts(called_user_id, &session).await?;
1682    Ok(())
1683}
1684
1685/// Decline an incoming call.
1686async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1687    let room_id = RoomId::from_proto(message.room_id);
1688    {
1689        let room = session
1690            .db()
1691            .await
1692            .decline_call(Some(room_id), session.user_id())
1693            .await?
1694            .context("declining call")?;
1695        room_updated(&room, &session.peer);
1696    }
1697
1698    for connection_id in session
1699        .connection_pool()
1700        .await
1701        .user_connection_ids(session.user_id())
1702    {
1703        session
1704            .peer
1705            .send(
1706                connection_id,
1707                proto::CallCanceled {
1708                    room_id: room_id.to_proto(),
1709                },
1710            )
1711            .trace_err();
1712    }
1713    update_user_contacts(session.user_id(), &session).await?;
1714    Ok(())
1715}
1716
1717/// Updates other participants in the room with your current location.
1718async fn update_participant_location(
1719    request: proto::UpdateParticipantLocation,
1720    response: Response<proto::UpdateParticipantLocation>,
1721    session: Session,
1722) -> Result<()> {
1723    let room_id = RoomId::from_proto(request.room_id);
1724    let location = request.location.context("invalid location")?;
1725
1726    let db = session.db().await;
1727    let room = db
1728        .update_room_participant_location(room_id, session.connection_id, location)
1729        .await?;
1730
1731    room_updated(&room, &session.peer);
1732    response.send(proto::Ack {})?;
1733    Ok(())
1734}
1735
1736/// Share a project into the room.
1737async fn share_project(
1738    request: proto::ShareProject,
1739    response: Response<proto::ShareProject>,
1740    session: Session,
1741) -> Result<()> {
1742    let (project_id, room) = &*session
1743        .db()
1744        .await
1745        .share_project(
1746            RoomId::from_proto(request.room_id),
1747            session.connection_id,
1748            &request.worktrees,
1749            request.is_ssh_project,
1750        )
1751        .await?;
1752    response.send(proto::ShareProjectResponse {
1753        project_id: project_id.to_proto(),
1754    })?;
1755    room_updated(room, &session.peer);
1756
1757    Ok(())
1758}
1759
1760/// Unshare a project from the room.
1761async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1762    let project_id = ProjectId::from_proto(message.project_id);
1763    unshare_project_internal(project_id, session.connection_id, &session).await
1764}
1765
1766async fn unshare_project_internal(
1767    project_id: ProjectId,
1768    connection_id: ConnectionId,
1769    session: &Session,
1770) -> Result<()> {
1771    let delete = {
1772        let room_guard = session
1773            .db()
1774            .await
1775            .unshare_project(project_id, connection_id)
1776            .await?;
1777
1778        let (delete, room, guest_connection_ids) = &*room_guard;
1779
1780        let message = proto::UnshareProject {
1781            project_id: project_id.to_proto(),
1782        };
1783
1784        broadcast(
1785            Some(connection_id),
1786            guest_connection_ids.iter().copied(),
1787            |conn_id| session.peer.send(conn_id, message.clone()),
1788        );
1789        if let Some(room) = room {
1790            room_updated(room, &session.peer);
1791        }
1792
1793        *delete
1794    };
1795
1796    if delete {
1797        let db = session.db().await;
1798        db.delete_project(project_id).await?;
1799    }
1800
1801    Ok(())
1802}
1803
1804/// Join someone elses shared project.
1805async fn join_project(
1806    request: proto::JoinProject,
1807    response: Response<proto::JoinProject>,
1808    session: Session,
1809) -> Result<()> {
1810    let project_id = ProjectId::from_proto(request.project_id);
1811
1812    tracing::info!(%project_id, "join project");
1813
1814    let db = session.db().await;
1815    let (project, replica_id) = &mut *db
1816        .join_project(project_id, session.connection_id, session.user_id())
1817        .await?;
1818    drop(db);
1819    tracing::info!(%project_id, "join remote project");
1820    join_project_internal(response, session, project, replica_id)
1821}
1822
1823trait JoinProjectInternalResponse {
1824    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1825}
1826impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1827    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1828        Response::<proto::JoinProject>::send(self, result)
1829    }
1830}
1831
1832fn join_project_internal(
1833    response: impl JoinProjectInternalResponse,
1834    session: Session,
1835    project: &mut Project,
1836    replica_id: &ReplicaId,
1837) -> Result<()> {
1838    let collaborators = project
1839        .collaborators
1840        .iter()
1841        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842        .map(|collaborator| collaborator.to_proto())
1843        .collect::<Vec<_>>();
1844    let project_id = project.id;
1845    let guest_user_id = session.user_id();
1846
1847    let worktrees = project
1848        .worktrees
1849        .iter()
1850        .map(|(id, worktree)| proto::WorktreeMetadata {
1851            id: *id,
1852            root_name: worktree.root_name.clone(),
1853            visible: worktree.visible,
1854            abs_path: worktree.abs_path.clone(),
1855        })
1856        .collect::<Vec<_>>();
1857
1858    let add_project_collaborator = proto::AddProjectCollaborator {
1859        project_id: project_id.to_proto(),
1860        collaborator: Some(proto::Collaborator {
1861            peer_id: Some(session.connection_id.into()),
1862            replica_id: replica_id.0 as u32,
1863            user_id: guest_user_id.to_proto(),
1864            is_host: false,
1865        }),
1866    };
1867
1868    for collaborator in &collaborators {
1869        session
1870            .peer
1871            .send(
1872                collaborator.peer_id.unwrap().into(),
1873                add_project_collaborator.clone(),
1874            )
1875            .trace_err();
1876    }
1877
1878    // First, we send the metadata associated with each worktree.
1879    response.send(proto::JoinProjectResponse {
1880        project_id: project.id.0 as u64,
1881        worktrees: worktrees.clone(),
1882        replica_id: replica_id.0 as u32,
1883        collaborators: collaborators.clone(),
1884        language_servers: project.language_servers.clone(),
1885        role: project.role.into(),
1886    })?;
1887
1888    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1889        // Stream this worktree's entries.
1890        let message = proto::UpdateWorktree {
1891            project_id: project_id.to_proto(),
1892            worktree_id,
1893            abs_path: worktree.abs_path.clone(),
1894            root_name: worktree.root_name,
1895            updated_entries: worktree.entries,
1896            removed_entries: Default::default(),
1897            scan_id: worktree.scan_id,
1898            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1899            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1900            removed_repositories: Default::default(),
1901        };
1902        for update in proto::split_worktree_update(message) {
1903            session.peer.send(session.connection_id, update.clone())?;
1904        }
1905
1906        // Stream this worktree's diagnostics.
1907        for summary in worktree.diagnostic_summaries {
1908            session.peer.send(
1909                session.connection_id,
1910                proto::UpdateDiagnosticSummary {
1911                    project_id: project_id.to_proto(),
1912                    worktree_id: worktree.id,
1913                    summary: Some(summary),
1914                },
1915            )?;
1916        }
1917
1918        for settings_file in worktree.settings_files {
1919            session.peer.send(
1920                session.connection_id,
1921                proto::UpdateWorktreeSettings {
1922                    project_id: project_id.to_proto(),
1923                    worktree_id: worktree.id,
1924                    path: settings_file.path,
1925                    content: Some(settings_file.content),
1926                    kind: Some(settings_file.kind.to_proto() as i32),
1927                },
1928            )?;
1929        }
1930    }
1931
1932    for repository in mem::take(&mut project.repositories) {
1933        for update in split_repository_update(repository) {
1934            session.peer.send(session.connection_id, update)?;
1935        }
1936    }
1937
1938    for language_server in &project.language_servers {
1939        session.peer.send(
1940            session.connection_id,
1941            proto::UpdateLanguageServer {
1942                project_id: project_id.to_proto(),
1943                language_server_id: language_server.id,
1944                variant: Some(
1945                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1946                        proto::LspDiskBasedDiagnosticsUpdated {},
1947                    ),
1948                ),
1949            },
1950        )?;
1951    }
1952
1953    Ok(())
1954}
1955
1956/// Leave someone elses shared project.
1957async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1958    let sender_id = session.connection_id;
1959    let project_id = ProjectId::from_proto(request.project_id);
1960    let db = session.db().await;
1961
1962    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963    tracing::info!(
1964        %project_id,
1965        "leave project"
1966    );
1967
1968    project_left(project, &session);
1969    if let Some(room) = room {
1970        room_updated(room, &session.peer);
1971    }
1972
1973    Ok(())
1974}
1975
1976/// Updates other participants with changes to the project
1977async fn update_project(
1978    request: proto::UpdateProject,
1979    response: Response<proto::UpdateProject>,
1980    session: Session,
1981) -> Result<()> {
1982    let project_id = ProjectId::from_proto(request.project_id);
1983    let (room, guest_connection_ids) = &*session
1984        .db()
1985        .await
1986        .update_project(project_id, session.connection_id, &request.worktrees)
1987        .await?;
1988    broadcast(
1989        Some(session.connection_id),
1990        guest_connection_ids.iter().copied(),
1991        |connection_id| {
1992            session
1993                .peer
1994                .forward_send(session.connection_id, connection_id, request.clone())
1995        },
1996    );
1997    if let Some(room) = room {
1998        room_updated(room, &session.peer);
1999    }
2000    response.send(proto::Ack {})?;
2001
2002    Ok(())
2003}
2004
2005/// Updates other participants with changes to the worktree
2006async fn update_worktree(
2007    request: proto::UpdateWorktree,
2008    response: Response<proto::UpdateWorktree>,
2009    session: Session,
2010) -> Result<()> {
2011    let guest_connection_ids = session
2012        .db()
2013        .await
2014        .update_worktree(&request, session.connection_id)
2015        .await?;
2016
2017    broadcast(
2018        Some(session.connection_id),
2019        guest_connection_ids.iter().copied(),
2020        |connection_id| {
2021            session
2022                .peer
2023                .forward_send(session.connection_id, connection_id, request.clone())
2024        },
2025    );
2026    response.send(proto::Ack {})?;
2027    Ok(())
2028}
2029
2030async fn update_repository(
2031    request: proto::UpdateRepository,
2032    response: Response<proto::UpdateRepository>,
2033    session: Session,
2034) -> Result<()> {
2035    let guest_connection_ids = session
2036        .db()
2037        .await
2038        .update_repository(&request, session.connection_id)
2039        .await?;
2040
2041    broadcast(
2042        Some(session.connection_id),
2043        guest_connection_ids.iter().copied(),
2044        |connection_id| {
2045            session
2046                .peer
2047                .forward_send(session.connection_id, connection_id, request.clone())
2048        },
2049    );
2050    response.send(proto::Ack {})?;
2051    Ok(())
2052}
2053
2054async fn remove_repository(
2055    request: proto::RemoveRepository,
2056    response: Response<proto::RemoveRepository>,
2057    session: Session,
2058) -> Result<()> {
2059    let guest_connection_ids = session
2060        .db()
2061        .await
2062        .remove_repository(&request, session.connection_id)
2063        .await?;
2064
2065    broadcast(
2066        Some(session.connection_id),
2067        guest_connection_ids.iter().copied(),
2068        |connection_id| {
2069            session
2070                .peer
2071                .forward_send(session.connection_id, connection_id, request.clone())
2072        },
2073    );
2074    response.send(proto::Ack {})?;
2075    Ok(())
2076}
2077
2078/// Updates other participants with changes to the diagnostics
2079async fn update_diagnostic_summary(
2080    message: proto::UpdateDiagnosticSummary,
2081    session: Session,
2082) -> Result<()> {
2083    let guest_connection_ids = session
2084        .db()
2085        .await
2086        .update_diagnostic_summary(&message, session.connection_id)
2087        .await?;
2088
2089    broadcast(
2090        Some(session.connection_id),
2091        guest_connection_ids.iter().copied(),
2092        |connection_id| {
2093            session
2094                .peer
2095                .forward_send(session.connection_id, connection_id, message.clone())
2096        },
2097    );
2098
2099    Ok(())
2100}
2101
2102/// Updates other participants with changes to the worktree settings
2103async fn update_worktree_settings(
2104    message: proto::UpdateWorktreeSettings,
2105    session: Session,
2106) -> Result<()> {
2107    let guest_connection_ids = session
2108        .db()
2109        .await
2110        .update_worktree_settings(&message, session.connection_id)
2111        .await?;
2112
2113    broadcast(
2114        Some(session.connection_id),
2115        guest_connection_ids.iter().copied(),
2116        |connection_id| {
2117            session
2118                .peer
2119                .forward_send(session.connection_id, connection_id, message.clone())
2120        },
2121    );
2122
2123    Ok(())
2124}
2125
2126/// Notify other participants that a language server has started.
2127async fn start_language_server(
2128    request: proto::StartLanguageServer,
2129    session: Session,
2130) -> Result<()> {
2131    let guest_connection_ids = session
2132        .db()
2133        .await
2134        .start_language_server(&request, session.connection_id)
2135        .await?;
2136
2137    broadcast(
2138        Some(session.connection_id),
2139        guest_connection_ids.iter().copied(),
2140        |connection_id| {
2141            session
2142                .peer
2143                .forward_send(session.connection_id, connection_id, request.clone())
2144        },
2145    );
2146    Ok(())
2147}
2148
2149/// Notify other participants that a language server has changed.
2150async fn update_language_server(
2151    request: proto::UpdateLanguageServer,
2152    session: Session,
2153) -> Result<()> {
2154    let project_id = ProjectId::from_proto(request.project_id);
2155    let project_connection_ids = session
2156        .db()
2157        .await
2158        .project_connection_ids(project_id, session.connection_id, true)
2159        .await?;
2160    broadcast(
2161        Some(session.connection_id),
2162        project_connection_ids.iter().copied(),
2163        |connection_id| {
2164            session
2165                .peer
2166                .forward_send(session.connection_id, connection_id, request.clone())
2167        },
2168    );
2169    Ok(())
2170}
2171
2172/// forward a project request to the host. These requests should be read only
2173/// as guests are allowed to send them.
2174async fn forward_read_only_project_request<T>(
2175    request: T,
2176    response: Response<T>,
2177    session: Session,
2178) -> Result<()>
2179where
2180    T: EntityMessage + RequestMessage,
2181{
2182    let project_id = ProjectId::from_proto(request.remote_entity_id());
2183    let host_connection_id = session
2184        .db()
2185        .await
2186        .host_for_read_only_project_request(project_id, session.connection_id)
2187        .await?;
2188    let payload = session
2189        .peer
2190        .forward_request(session.connection_id, host_connection_id, request)
2191        .await?;
2192    response.send(payload)?;
2193    Ok(())
2194}
2195
2196async fn forward_find_search_candidates_request(
2197    request: proto::FindSearchCandidates,
2198    response: Response<proto::FindSearchCandidates>,
2199    session: Session,
2200) -> Result<()> {
2201    let project_id = ProjectId::from_proto(request.remote_entity_id());
2202    let host_connection_id = session
2203        .db()
2204        .await
2205        .host_for_read_only_project_request(project_id, session.connection_id)
2206        .await?;
2207    let payload = session
2208        .peer
2209        .forward_request(session.connection_id, host_connection_id, request)
2210        .await?;
2211    response.send(payload)?;
2212    Ok(())
2213}
2214
2215/// forward a project request to the host. These requests are disallowed
2216/// for guests.
2217async fn forward_mutating_project_request<T>(
2218    request: T,
2219    response: Response<T>,
2220    session: Session,
2221) -> Result<()>
2222where
2223    T: EntityMessage + RequestMessage,
2224{
2225    let project_id = ProjectId::from_proto(request.remote_entity_id());
2226
2227    let host_connection_id = session
2228        .db()
2229        .await
2230        .host_for_mutating_project_request(project_id, session.connection_id)
2231        .await?;
2232    let payload = session
2233        .peer
2234        .forward_request(session.connection_id, host_connection_id, request)
2235        .await?;
2236    response.send(payload)?;
2237    Ok(())
2238}
2239
2240/// Notify other participants that a new buffer has been created
2241async fn create_buffer_for_peer(
2242    request: proto::CreateBufferForPeer,
2243    session: Session,
2244) -> Result<()> {
2245    session
2246        .db()
2247        .await
2248        .check_user_is_project_host(
2249            ProjectId::from_proto(request.project_id),
2250            session.connection_id,
2251        )
2252        .await?;
2253    let peer_id = request.peer_id.context("invalid peer id")?;
2254    session
2255        .peer
2256        .forward_send(session.connection_id, peer_id.into(), request)?;
2257    Ok(())
2258}
2259
2260/// Notify other participants that a buffer has been updated. This is
2261/// allowed for guests as long as the update is limited to selections.
2262async fn update_buffer(
2263    request: proto::UpdateBuffer,
2264    response: Response<proto::UpdateBuffer>,
2265    session: Session,
2266) -> Result<()> {
2267    let project_id = ProjectId::from_proto(request.project_id);
2268    let mut capability = Capability::ReadOnly;
2269
2270    for op in request.operations.iter() {
2271        match op.variant {
2272            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2273            Some(_) => capability = Capability::ReadWrite,
2274        }
2275    }
2276
2277    let host = {
2278        let guard = session
2279            .db()
2280            .await
2281            .connections_for_buffer_update(project_id, session.connection_id, capability)
2282            .await?;
2283
2284        let (host, guests) = &*guard;
2285
2286        broadcast(
2287            Some(session.connection_id),
2288            guests.clone(),
2289            |connection_id| {
2290                session
2291                    .peer
2292                    .forward_send(session.connection_id, connection_id, request.clone())
2293            },
2294        );
2295
2296        *host
2297    };
2298
2299    if host != session.connection_id {
2300        session
2301            .peer
2302            .forward_request(session.connection_id, host, request.clone())
2303            .await?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2311    let project_id = ProjectId::from_proto(message.project_id);
2312
2313    let operation = message.operation.as_ref().context("invalid operation")?;
2314    let capability = match operation.variant.as_ref() {
2315        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2316            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2317                match buffer_op.variant {
2318                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2319                        Capability::ReadOnly
2320                    }
2321                    _ => Capability::ReadWrite,
2322                }
2323            } else {
2324                Capability::ReadWrite
2325            }
2326        }
2327        Some(_) => Capability::ReadWrite,
2328        None => Capability::ReadOnly,
2329    };
2330
2331    let guard = session
2332        .db()
2333        .await
2334        .connections_for_buffer_update(project_id, session.connection_id, capability)
2335        .await?;
2336
2337    let (host, guests) = &*guard;
2338
2339    broadcast(
2340        Some(session.connection_id),
2341        guests.iter().chain([host]).copied(),
2342        |connection_id| {
2343            session
2344                .peer
2345                .forward_send(session.connection_id, connection_id, message.clone())
2346        },
2347    );
2348
2349    Ok(())
2350}
2351
2352/// Notify other participants that a project has been updated.
2353async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2354    request: T,
2355    session: Session,
2356) -> Result<()> {
2357    let project_id = ProjectId::from_proto(request.remote_entity_id());
2358    let project_connection_ids = session
2359        .db()
2360        .await
2361        .project_connection_ids(project_id, session.connection_id, false)
2362        .await?;
2363
2364    broadcast(
2365        Some(session.connection_id),
2366        project_connection_ids.iter().copied(),
2367        |connection_id| {
2368            session
2369                .peer
2370                .forward_send(session.connection_id, connection_id, request.clone())
2371        },
2372    );
2373    Ok(())
2374}
2375
2376/// Start following another user in a call.
2377async fn follow(
2378    request: proto::Follow,
2379    response: Response<proto::Follow>,
2380    session: Session,
2381) -> Result<()> {
2382    let room_id = RoomId::from_proto(request.room_id);
2383    let project_id = request.project_id.map(ProjectId::from_proto);
2384    let leader_id = request.leader_id.context("invalid leader id")?.into();
2385    let follower_id = session.connection_id;
2386
2387    session
2388        .db()
2389        .await
2390        .check_room_participants(room_id, leader_id, session.connection_id)
2391        .await?;
2392
2393    let response_payload = session
2394        .peer
2395        .forward_request(session.connection_id, leader_id, request)
2396        .await?;
2397    response.send(response_payload)?;
2398
2399    if let Some(project_id) = project_id {
2400        let room = session
2401            .db()
2402            .await
2403            .follow(room_id, project_id, leader_id, follower_id)
2404            .await?;
2405        room_updated(&room, &session.peer);
2406    }
2407
2408    Ok(())
2409}
2410
2411/// Stop following another user in a call.
2412async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2413    let room_id = RoomId::from_proto(request.room_id);
2414    let project_id = request.project_id.map(ProjectId::from_proto);
2415    let leader_id = request.leader_id.context("invalid leader id")?.into();
2416    let follower_id = session.connection_id;
2417
2418    session
2419        .db()
2420        .await
2421        .check_room_participants(room_id, leader_id, session.connection_id)
2422        .await?;
2423
2424    session
2425        .peer
2426        .forward_send(session.connection_id, leader_id, request)?;
2427
2428    if let Some(project_id) = project_id {
2429        let room = session
2430            .db()
2431            .await
2432            .unfollow(room_id, project_id, leader_id, follower_id)
2433            .await?;
2434        room_updated(&room, &session.peer);
2435    }
2436
2437    Ok(())
2438}
2439
2440/// Notify everyone following you of your current location.
2441async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2442    let room_id = RoomId::from_proto(request.room_id);
2443    let database = session.db.lock().await;
2444
2445    let connection_ids = if let Some(project_id) = request.project_id {
2446        let project_id = ProjectId::from_proto(project_id);
2447        database
2448            .project_connection_ids(project_id, session.connection_id, true)
2449            .await?
2450    } else {
2451        database
2452            .room_connection_ids(room_id, session.connection_id)
2453            .await?
2454    };
2455
2456    // For now, don't send view update messages back to that view's current leader.
2457    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2458        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2459        _ => None,
2460    });
2461
2462    for connection_id in connection_ids.iter().cloned() {
2463        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2464            session
2465                .peer
2466                .forward_send(session.connection_id, connection_id, request.clone())?;
2467        }
2468    }
2469    Ok(())
2470}
2471
2472/// Get public data about users.
2473async fn get_users(
2474    request: proto::GetUsers,
2475    response: Response<proto::GetUsers>,
2476    session: Session,
2477) -> Result<()> {
2478    let user_ids = request
2479        .user_ids
2480        .into_iter()
2481        .map(UserId::from_proto)
2482        .collect();
2483    let users = session
2484        .db()
2485        .await
2486        .get_users_by_ids(user_ids)
2487        .await?
2488        .into_iter()
2489        .map(|user| proto::User {
2490            id: user.id.to_proto(),
2491            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2492            github_login: user.github_login,
2493            email: user.email_address,
2494            name: user.name,
2495        })
2496        .collect();
2497    response.send(proto::UsersResponse { users })?;
2498    Ok(())
2499}
2500
2501/// Search for users (to invite) buy Github login
2502async fn fuzzy_search_users(
2503    request: proto::FuzzySearchUsers,
2504    response: Response<proto::FuzzySearchUsers>,
2505    session: Session,
2506) -> Result<()> {
2507    let query = request.query;
2508    let users = match query.len() {
2509        0 => vec![],
2510        1 | 2 => session
2511            .db()
2512            .await
2513            .get_user_by_github_login(&query)
2514            .await?
2515            .into_iter()
2516            .collect(),
2517        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2518    };
2519    let users = users
2520        .into_iter()
2521        .filter(|user| user.id != session.user_id())
2522        .map(|user| proto::User {
2523            id: user.id.to_proto(),
2524            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2525            github_login: user.github_login,
2526            name: user.name,
2527            email: user.email_address,
2528        })
2529        .collect();
2530    response.send(proto::UsersResponse { users })?;
2531    Ok(())
2532}
2533
2534/// Send a contact request to another user.
2535async fn request_contact(
2536    request: proto::RequestContact,
2537    response: Response<proto::RequestContact>,
2538    session: Session,
2539) -> Result<()> {
2540    let requester_id = session.user_id();
2541    let responder_id = UserId::from_proto(request.responder_id);
2542    if requester_id == responder_id {
2543        return Err(anyhow!("cannot add yourself as a contact"))?;
2544    }
2545
2546    let notifications = session
2547        .db()
2548        .await
2549        .send_contact_request(requester_id, responder_id)
2550        .await?;
2551
2552    // Update outgoing contact requests of requester
2553    let mut update = proto::UpdateContacts::default();
2554    update.outgoing_requests.push(responder_id.to_proto());
2555    for connection_id in session
2556        .connection_pool()
2557        .await
2558        .user_connection_ids(requester_id)
2559    {
2560        session.peer.send(connection_id, update.clone())?;
2561    }
2562
2563    // Update incoming contact requests of responder
2564    let mut update = proto::UpdateContacts::default();
2565    update
2566        .incoming_requests
2567        .push(proto::IncomingContactRequest {
2568            requester_id: requester_id.to_proto(),
2569        });
2570    let connection_pool = session.connection_pool().await;
2571    for connection_id in connection_pool.user_connection_ids(responder_id) {
2572        session.peer.send(connection_id, update.clone())?;
2573    }
2574
2575    send_notifications(&connection_pool, &session.peer, notifications);
2576
2577    response.send(proto::Ack {})?;
2578    Ok(())
2579}
2580
2581/// Accept or decline a contact request
2582async fn respond_to_contact_request(
2583    request: proto::RespondToContactRequest,
2584    response: Response<proto::RespondToContactRequest>,
2585    session: Session,
2586) -> Result<()> {
2587    let responder_id = session.user_id();
2588    let requester_id = UserId::from_proto(request.requester_id);
2589    let db = session.db().await;
2590    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2591        db.dismiss_contact_notification(responder_id, requester_id)
2592            .await?;
2593    } else {
2594        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2595
2596        let notifications = db
2597            .respond_to_contact_request(responder_id, requester_id, accept)
2598            .await?;
2599        let requester_busy = db.is_user_busy(requester_id).await?;
2600        let responder_busy = db.is_user_busy(responder_id).await?;
2601
2602        let pool = session.connection_pool().await;
2603        // Update responder with new contact
2604        let mut update = proto::UpdateContacts::default();
2605        if accept {
2606            update
2607                .contacts
2608                .push(contact_for_user(requester_id, requester_busy, &pool));
2609        }
2610        update
2611            .remove_incoming_requests
2612            .push(requester_id.to_proto());
2613        for connection_id in pool.user_connection_ids(responder_id) {
2614            session.peer.send(connection_id, update.clone())?;
2615        }
2616
2617        // Update requester with new contact
2618        let mut update = proto::UpdateContacts::default();
2619        if accept {
2620            update
2621                .contacts
2622                .push(contact_for_user(responder_id, responder_busy, &pool));
2623        }
2624        update
2625            .remove_outgoing_requests
2626            .push(responder_id.to_proto());
2627
2628        for connection_id in pool.user_connection_ids(requester_id) {
2629            session.peer.send(connection_id, update.clone())?;
2630        }
2631
2632        send_notifications(&pool, &session.peer, notifications);
2633    }
2634
2635    response.send(proto::Ack {})?;
2636    Ok(())
2637}
2638
2639/// Remove a contact.
2640async fn remove_contact(
2641    request: proto::RemoveContact,
2642    response: Response<proto::RemoveContact>,
2643    session: Session,
2644) -> Result<()> {
2645    let requester_id = session.user_id();
2646    let responder_id = UserId::from_proto(request.user_id);
2647    let db = session.db().await;
2648    let (contact_accepted, deleted_notification_id) =
2649        db.remove_contact(requester_id, responder_id).await?;
2650
2651    let pool = session.connection_pool().await;
2652    // Update outgoing contact requests of requester
2653    let mut update = proto::UpdateContacts::default();
2654    if contact_accepted {
2655        update.remove_contacts.push(responder_id.to_proto());
2656    } else {
2657        update
2658            .remove_outgoing_requests
2659            .push(responder_id.to_proto());
2660    }
2661    for connection_id in pool.user_connection_ids(requester_id) {
2662        session.peer.send(connection_id, update.clone())?;
2663    }
2664
2665    // Update incoming contact requests of responder
2666    let mut update = proto::UpdateContacts::default();
2667    if contact_accepted {
2668        update.remove_contacts.push(requester_id.to_proto());
2669    } else {
2670        update
2671            .remove_incoming_requests
2672            .push(requester_id.to_proto());
2673    }
2674    for connection_id in pool.user_connection_ids(responder_id) {
2675        session.peer.send(connection_id, update.clone())?;
2676        if let Some(notification_id) = deleted_notification_id {
2677            session.peer.send(
2678                connection_id,
2679                proto::DeleteNotification {
2680                    notification_id: notification_id.to_proto(),
2681                },
2682            )?;
2683        }
2684    }
2685
2686    response.send(proto::Ack {})?;
2687    Ok(())
2688}
2689
2690fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2691    version.0.minor() < 139
2692}
2693
2694async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2695    if is_staff {
2696        return Ok(proto::Plan::ZedPro);
2697    }
2698
2699    let subscription = db.get_active_billing_subscription(user_id).await?;
2700    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2701
2702    let plan = if let Some(subscription_kind) = subscription_kind {
2703        match subscription_kind {
2704            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2705            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2706            SubscriptionKind::ZedFree => proto::Plan::Free,
2707        }
2708    } else {
2709        proto::Plan::Free
2710    };
2711
2712    Ok(plan)
2713}
2714
2715async fn make_update_user_plan_message(
2716    user: &User,
2717    is_staff: bool,
2718    db: &Arc<Database>,
2719    llm_db: Option<Arc<LlmDatabase>>,
2720) -> Result<proto::UpdateUserPlan> {
2721    let feature_flags = db.get_user_flags(user.id).await?;
2722    let plan = current_plan(db, user.id, is_staff).await?;
2723    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2724    let billing_preferences = db.get_billing_preferences(user.id).await?;
2725
2726    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2727        let subscription = db.get_active_billing_subscription(user.id).await?;
2728
2729        let subscription_period =
2730            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2731
2732        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2733            llm_db
2734                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2735                .await?
2736        } else {
2737            None
2738        };
2739
2740        (subscription_period, usage)
2741    } else {
2742        (None, None)
2743    };
2744
2745    let account_too_young =
2746        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2747
2748    Ok(proto::UpdateUserPlan {
2749        plan: plan.into(),
2750        trial_started_at: billing_customer
2751            .as_ref()
2752            .and_then(|billing_customer| billing_customer.trial_started_at)
2753            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2754        is_usage_based_billing_enabled: if is_staff {
2755            Some(true)
2756        } else {
2757            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2758        },
2759        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2760            proto::SubscriptionPeriod {
2761                started_at: started_at.timestamp() as u64,
2762                ended_at: ended_at.timestamp() as u64,
2763            }
2764        }),
2765        account_too_young: Some(account_too_young),
2766        has_overdue_invoices: billing_customer
2767            .map(|billing_customer| billing_customer.has_overdue_invoices),
2768        usage: usage.map(|usage| {
2769            let plan = match plan {
2770                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2771                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2772                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2773            };
2774
2775            let model_requests_limit = match plan.model_requests_limit() {
2776                zed_llm_client::UsageLimit::Limited(limit) => {
2777                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2778                        && feature_flags
2779                            .iter()
2780                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2781                    {
2782                        1_000
2783                    } else {
2784                        limit
2785                    };
2786
2787                    zed_llm_client::UsageLimit::Limited(limit)
2788                }
2789                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2790            };
2791
2792            proto::SubscriptionUsage {
2793                model_requests_usage_amount: usage.model_requests as u32,
2794                model_requests_usage_limit: Some(proto::UsageLimit {
2795                    variant: Some(match model_requests_limit {
2796                        zed_llm_client::UsageLimit::Limited(limit) => {
2797                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2798                                limit: limit as u32,
2799                            })
2800                        }
2801                        zed_llm_client::UsageLimit::Unlimited => {
2802                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2803                        }
2804                    }),
2805                }),
2806                edit_predictions_usage_amount: usage.edit_predictions as u32,
2807                edit_predictions_usage_limit: Some(proto::UsageLimit {
2808                    variant: Some(match plan.edit_predictions_limit() {
2809                        zed_llm_client::UsageLimit::Limited(limit) => {
2810                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2811                                limit: limit as u32,
2812                            })
2813                        }
2814                        zed_llm_client::UsageLimit::Unlimited => {
2815                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2816                        }
2817                    }),
2818                }),
2819            }
2820        }),
2821    })
2822}
2823
2824async fn update_user_plan(session: &Session) -> Result<()> {
2825    let db = session.db().await;
2826
2827    let update_user_plan = make_update_user_plan_message(
2828        session.principal.user(),
2829        session.is_staff(),
2830        &db.0,
2831        session.app_state.llm_db.clone(),
2832    )
2833    .await?;
2834
2835    session
2836        .peer
2837        .send(session.connection_id, update_user_plan)
2838        .trace_err();
2839
2840    Ok(())
2841}
2842
2843async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2844    subscribe_user_to_channels(session.user_id(), &session).await?;
2845    Ok(())
2846}
2847
2848async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2849    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2850    let mut pool = session.connection_pool().await;
2851    for membership in &channels_for_user.channel_memberships {
2852        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2853    }
2854    session.peer.send(
2855        session.connection_id,
2856        build_update_user_channels(&channels_for_user),
2857    )?;
2858    session.peer.send(
2859        session.connection_id,
2860        build_channels_update(channels_for_user),
2861    )?;
2862    Ok(())
2863}
2864
2865/// Creates a new channel.
2866async fn create_channel(
2867    request: proto::CreateChannel,
2868    response: Response<proto::CreateChannel>,
2869    session: Session,
2870) -> Result<()> {
2871    let db = session.db().await;
2872
2873    let parent_id = request.parent_id.map(ChannelId::from_proto);
2874    let (channel, membership) = db
2875        .create_channel(&request.name, parent_id, session.user_id())
2876        .await?;
2877
2878    let root_id = channel.root_id();
2879    let channel = Channel::from_model(channel);
2880
2881    response.send(proto::CreateChannelResponse {
2882        channel: Some(channel.to_proto()),
2883        parent_id: request.parent_id,
2884    })?;
2885
2886    let mut connection_pool = session.connection_pool().await;
2887    if let Some(membership) = membership {
2888        connection_pool.subscribe_to_channel(
2889            membership.user_id,
2890            membership.channel_id,
2891            membership.role,
2892        );
2893        let update = proto::UpdateUserChannels {
2894            channel_memberships: vec![proto::ChannelMembership {
2895                channel_id: membership.channel_id.to_proto(),
2896                role: membership.role.into(),
2897            }],
2898            ..Default::default()
2899        };
2900        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2901            session.peer.send(connection_id, update.clone())?;
2902        }
2903    }
2904
2905    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2906        if !role.can_see_channel(channel.visibility) {
2907            continue;
2908        }
2909
2910        let update = proto::UpdateChannels {
2911            channels: vec![channel.to_proto()],
2912            ..Default::default()
2913        };
2914        session.peer.send(connection_id, update.clone())?;
2915    }
2916
2917    Ok(())
2918}
2919
2920/// Delete a channel
2921async fn delete_channel(
2922    request: proto::DeleteChannel,
2923    response: Response<proto::DeleteChannel>,
2924    session: Session,
2925) -> Result<()> {
2926    let db = session.db().await;
2927
2928    let channel_id = request.channel_id;
2929    let (root_channel, removed_channels) = db
2930        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2931        .await?;
2932    response.send(proto::Ack {})?;
2933
2934    // Notify members of removed channels
2935    let mut update = proto::UpdateChannels::default();
2936    update
2937        .delete_channels
2938        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2939
2940    let connection_pool = session.connection_pool().await;
2941    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2942        session.peer.send(connection_id, update.clone())?;
2943    }
2944
2945    Ok(())
2946}
2947
2948/// Invite someone to join a channel.
2949async fn invite_channel_member(
2950    request: proto::InviteChannelMember,
2951    response: Response<proto::InviteChannelMember>,
2952    session: Session,
2953) -> Result<()> {
2954    let db = session.db().await;
2955    let channel_id = ChannelId::from_proto(request.channel_id);
2956    let invitee_id = UserId::from_proto(request.user_id);
2957    let InviteMemberResult {
2958        channel,
2959        notifications,
2960    } = db
2961        .invite_channel_member(
2962            channel_id,
2963            invitee_id,
2964            session.user_id(),
2965            request.role().into(),
2966        )
2967        .await?;
2968
2969    let update = proto::UpdateChannels {
2970        channel_invitations: vec![channel.to_proto()],
2971        ..Default::default()
2972    };
2973
2974    let connection_pool = session.connection_pool().await;
2975    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2976        session.peer.send(connection_id, update.clone())?;
2977    }
2978
2979    send_notifications(&connection_pool, &session.peer, notifications);
2980
2981    response.send(proto::Ack {})?;
2982    Ok(())
2983}
2984
2985/// remove someone from a channel
2986async fn remove_channel_member(
2987    request: proto::RemoveChannelMember,
2988    response: Response<proto::RemoveChannelMember>,
2989    session: Session,
2990) -> Result<()> {
2991    let db = session.db().await;
2992    let channel_id = ChannelId::from_proto(request.channel_id);
2993    let member_id = UserId::from_proto(request.user_id);
2994
2995    let RemoveChannelMemberResult {
2996        membership_update,
2997        notification_id,
2998    } = db
2999        .remove_channel_member(channel_id, member_id, session.user_id())
3000        .await?;
3001
3002    let mut connection_pool = session.connection_pool().await;
3003    notify_membership_updated(
3004        &mut connection_pool,
3005        membership_update,
3006        member_id,
3007        &session.peer,
3008    );
3009    for connection_id in connection_pool.user_connection_ids(member_id) {
3010        if let Some(notification_id) = notification_id {
3011            session
3012                .peer
3013                .send(
3014                    connection_id,
3015                    proto::DeleteNotification {
3016                        notification_id: notification_id.to_proto(),
3017                    },
3018                )
3019                .trace_err();
3020        }
3021    }
3022
3023    response.send(proto::Ack {})?;
3024    Ok(())
3025}
3026
3027/// Toggle the channel between public and private.
3028/// Care is taken to maintain the invariant that public channels only descend from public channels,
3029/// (though members-only channels can appear at any point in the hierarchy).
3030async fn set_channel_visibility(
3031    request: proto::SetChannelVisibility,
3032    response: Response<proto::SetChannelVisibility>,
3033    session: Session,
3034) -> Result<()> {
3035    let db = session.db().await;
3036    let channel_id = ChannelId::from_proto(request.channel_id);
3037    let visibility = request.visibility().into();
3038
3039    let channel_model = db
3040        .set_channel_visibility(channel_id, visibility, session.user_id())
3041        .await?;
3042    let root_id = channel_model.root_id();
3043    let channel = Channel::from_model(channel_model);
3044
3045    let mut connection_pool = session.connection_pool().await;
3046    for (user_id, role) in connection_pool
3047        .channel_user_ids(root_id)
3048        .collect::<Vec<_>>()
3049        .into_iter()
3050    {
3051        let update = if role.can_see_channel(channel.visibility) {
3052            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3053            proto::UpdateChannels {
3054                channels: vec![channel.to_proto()],
3055                ..Default::default()
3056            }
3057        } else {
3058            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3059            proto::UpdateChannels {
3060                delete_channels: vec![channel.id.to_proto()],
3061                ..Default::default()
3062            }
3063        };
3064
3065        for connection_id in connection_pool.user_connection_ids(user_id) {
3066            session.peer.send(connection_id, update.clone())?;
3067        }
3068    }
3069
3070    response.send(proto::Ack {})?;
3071    Ok(())
3072}
3073
3074/// Alter the role for a user in the channel.
3075async fn set_channel_member_role(
3076    request: proto::SetChannelMemberRole,
3077    response: Response<proto::SetChannelMemberRole>,
3078    session: Session,
3079) -> Result<()> {
3080    let db = session.db().await;
3081    let channel_id = ChannelId::from_proto(request.channel_id);
3082    let member_id = UserId::from_proto(request.user_id);
3083    let result = db
3084        .set_channel_member_role(
3085            channel_id,
3086            session.user_id(),
3087            member_id,
3088            request.role().into(),
3089        )
3090        .await?;
3091
3092    match result {
3093        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3094            let mut connection_pool = session.connection_pool().await;
3095            notify_membership_updated(
3096                &mut connection_pool,
3097                membership_update,
3098                member_id,
3099                &session.peer,
3100            )
3101        }
3102        db::SetMemberRoleResult::InviteUpdated(channel) => {
3103            let update = proto::UpdateChannels {
3104                channel_invitations: vec![channel.to_proto()],
3105                ..Default::default()
3106            };
3107
3108            for connection_id in session
3109                .connection_pool()
3110                .await
3111                .user_connection_ids(member_id)
3112            {
3113                session.peer.send(connection_id, update.clone())?;
3114            }
3115        }
3116    }
3117
3118    response.send(proto::Ack {})?;
3119    Ok(())
3120}
3121
3122/// Change the name of a channel
3123async fn rename_channel(
3124    request: proto::RenameChannel,
3125    response: Response<proto::RenameChannel>,
3126    session: Session,
3127) -> Result<()> {
3128    let db = session.db().await;
3129    let channel_id = ChannelId::from_proto(request.channel_id);
3130    let channel_model = db
3131        .rename_channel(channel_id, session.user_id(), &request.name)
3132        .await?;
3133    let root_id = channel_model.root_id();
3134    let channel = Channel::from_model(channel_model);
3135
3136    response.send(proto::RenameChannelResponse {
3137        channel: Some(channel.to_proto()),
3138    })?;
3139
3140    let connection_pool = session.connection_pool().await;
3141    let update = proto::UpdateChannels {
3142        channels: vec![channel.to_proto()],
3143        ..Default::default()
3144    };
3145    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3146        if role.can_see_channel(channel.visibility) {
3147            session.peer.send(connection_id, update.clone())?;
3148        }
3149    }
3150
3151    Ok(())
3152}
3153
3154/// Move a channel to a new parent.
3155async fn move_channel(
3156    request: proto::MoveChannel,
3157    response: Response<proto::MoveChannel>,
3158    session: Session,
3159) -> Result<()> {
3160    let channel_id = ChannelId::from_proto(request.channel_id);
3161    let to = ChannelId::from_proto(request.to);
3162
3163    let (root_id, channels) = session
3164        .db()
3165        .await
3166        .move_channel(channel_id, to, session.user_id())
3167        .await?;
3168
3169    let connection_pool = session.connection_pool().await;
3170    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3171        let channels = channels
3172            .iter()
3173            .filter_map(|channel| {
3174                if role.can_see_channel(channel.visibility) {
3175                    Some(channel.to_proto())
3176                } else {
3177                    None
3178                }
3179            })
3180            .collect::<Vec<_>>();
3181        if channels.is_empty() {
3182            continue;
3183        }
3184
3185        let update = proto::UpdateChannels {
3186            channels,
3187            ..Default::default()
3188        };
3189
3190        session.peer.send(connection_id, update.clone())?;
3191    }
3192
3193    response.send(Ack {})?;
3194    Ok(())
3195}
3196
3197/// Get the list of channel members
3198async fn get_channel_members(
3199    request: proto::GetChannelMembers,
3200    response: Response<proto::GetChannelMembers>,
3201    session: Session,
3202) -> Result<()> {
3203    let db = session.db().await;
3204    let channel_id = ChannelId::from_proto(request.channel_id);
3205    let limit = if request.limit == 0 {
3206        u16::MAX as u64
3207    } else {
3208        request.limit
3209    };
3210    let (members, users) = db
3211        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3212        .await?;
3213    response.send(proto::GetChannelMembersResponse { members, users })?;
3214    Ok(())
3215}
3216
3217/// Accept or decline a channel invitation.
3218async fn respond_to_channel_invite(
3219    request: proto::RespondToChannelInvite,
3220    response: Response<proto::RespondToChannelInvite>,
3221    session: Session,
3222) -> Result<()> {
3223    let db = session.db().await;
3224    let channel_id = ChannelId::from_proto(request.channel_id);
3225    let RespondToChannelInvite {
3226        membership_update,
3227        notifications,
3228    } = db
3229        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3230        .await?;
3231
3232    let mut connection_pool = session.connection_pool().await;
3233    if let Some(membership_update) = membership_update {
3234        notify_membership_updated(
3235            &mut connection_pool,
3236            membership_update,
3237            session.user_id(),
3238            &session.peer,
3239        );
3240    } else {
3241        let update = proto::UpdateChannels {
3242            remove_channel_invitations: vec![channel_id.to_proto()],
3243            ..Default::default()
3244        };
3245
3246        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3247            session.peer.send(connection_id, update.clone())?;
3248        }
3249    };
3250
3251    send_notifications(&connection_pool, &session.peer, notifications);
3252
3253    response.send(proto::Ack {})?;
3254
3255    Ok(())
3256}
3257
3258/// Join the channels' room
3259async fn join_channel(
3260    request: proto::JoinChannel,
3261    response: Response<proto::JoinChannel>,
3262    session: Session,
3263) -> Result<()> {
3264    let channel_id = ChannelId::from_proto(request.channel_id);
3265    join_channel_internal(channel_id, Box::new(response), session).await
3266}
3267
/// A responder that can complete a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` (which takes `Box<impl JoinChannelInternalResponse>`) can serve
/// either request without knowing which one it is answering.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s inherent `send`. Inherent items take precedence over
        // this trait method, so this delegates instead of recursing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3281
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// `response` may be any `JoinChannelInternalResponse` (this file implements it for
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`).
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, its channel id, and LiveKit connection info (when a
///    LiveKit client is configured and a token could be created).
/// 4. Apply any membership change to the connection pool and notify the room's
///    participants via `room_updated`.
/// 5. Notify about the channel via `channel_updated` and call `update_user_contacts`.
///
/// # Errors
///
/// Returns an error if a database call, the reply, `leave_room_for_session`, or
/// `update_user_contacts` fails, or if the joined room has no channel. The
/// missing-channel check happens after the reply in step 3 has been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle before leaving the old room and re-acquire it
            // afterwards. NOTE(review): presumably `leave_room_for_session` acquires
            // the db handle itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit info is optional. It is `None` when no LiveKit client is
        // configured, and also when creating a token fails: `trace_err` logs the
        // error and the `?` turns it into `None`, so the join itself still
        // succeeds. Guests get a guest token and may not publish; every other
        // role gets a room token and may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before broadcasting anything to other peers.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The pool guard lives until the end of this block, so `room_updated`
        // also runs while it is held.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails if the joined room has no channel. The join reply above has
    // already been sent by this point.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3375
3376/// Start editing the channel notes
3377async fn join_channel_buffer(
3378    request: proto::JoinChannelBuffer,
3379    response: Response<proto::JoinChannelBuffer>,
3380    session: Session,
3381) -> Result<()> {
3382    let db = session.db().await;
3383    let channel_id = ChannelId::from_proto(request.channel_id);
3384
3385    let open_response = db
3386        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3387        .await?;
3388
3389    let collaborators = open_response.collaborators.clone();
3390    response.send(open_response)?;
3391
3392    let update = UpdateChannelBufferCollaborators {
3393        channel_id: channel_id.to_proto(),
3394        collaborators: collaborators.clone(),
3395    };
3396    channel_buffer_updated(
3397        session.connection_id,
3398        collaborators
3399            .iter()
3400            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3401        &update,
3402        &session.peer,
3403    );
3404
3405    Ok(())
3406}
3407
3408/// Edit the channel notes
3409async fn update_channel_buffer(
3410    request: proto::UpdateChannelBuffer,
3411    session: Session,
3412) -> Result<()> {
3413    let db = session.db().await;
3414    let channel_id = ChannelId::from_proto(request.channel_id);
3415
3416    let (collaborators, epoch, version) = db
3417        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3418        .await?;
3419
3420    channel_buffer_updated(
3421        session.connection_id,
3422        collaborators.clone(),
3423        &proto::UpdateChannelBuffer {
3424            channel_id: channel_id.to_proto(),
3425            operations: request.operations,
3426        },
3427        &session.peer,
3428    );
3429
3430    let pool = &*session.connection_pool().await;
3431
3432    let non_collaborators =
3433        pool.channel_connection_ids(channel_id)
3434            .filter_map(|(connection_id, _)| {
3435                if collaborators.contains(&connection_id) {
3436                    None
3437                } else {
3438                    Some(connection_id)
3439                }
3440            });
3441
3442    broadcast(None, non_collaborators, |peer_id| {
3443        session.peer.send(
3444            peer_id,
3445            proto::UpdateChannels {
3446                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3447                    channel_id: channel_id.to_proto(),
3448                    epoch: epoch as u64,
3449                    version: version.clone(),
3450                }],
3451                ..Default::default()
3452            },
3453        )
3454    });
3455
3456    Ok(())
3457}
3458
3459/// Rejoin the channel notes after a connection blip
3460async fn rejoin_channel_buffers(
3461    request: proto::RejoinChannelBuffers,
3462    response: Response<proto::RejoinChannelBuffers>,
3463    session: Session,
3464) -> Result<()> {
3465    let db = session.db().await;
3466    let buffers = db
3467        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3468        .await?;
3469
3470    for rejoined_buffer in &buffers {
3471        let collaborators_to_notify = rejoined_buffer
3472            .buffer
3473            .collaborators
3474            .iter()
3475            .filter_map(|c| Some(c.peer_id?.into()));
3476        channel_buffer_updated(
3477            session.connection_id,
3478            collaborators_to_notify,
3479            &proto::UpdateChannelBufferCollaborators {
3480                channel_id: rejoined_buffer.buffer.channel_id,
3481                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3482            },
3483            &session.peer,
3484        );
3485    }
3486
3487    response.send(proto::RejoinChannelBuffersResponse {
3488        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3489    })?;
3490
3491    Ok(())
3492}
3493
3494/// Stop editing the channel notes
3495async fn leave_channel_buffer(
3496    request: proto::LeaveChannelBuffer,
3497    response: Response<proto::LeaveChannelBuffer>,
3498    session: Session,
3499) -> Result<()> {
3500    let db = session.db().await;
3501    let channel_id = ChannelId::from_proto(request.channel_id);
3502
3503    let left_buffer = db
3504        .leave_channel_buffer(channel_id, session.connection_id)
3505        .await?;
3506
3507    response.send(Ack {})?;
3508
3509    channel_buffer_updated(
3510        session.connection_id,
3511        left_buffer.connections,
3512        &proto::UpdateChannelBufferCollaborators {
3513            channel_id: channel_id.to_proto(),
3514            collaborators: left_buffer.collaborators,
3515        },
3516        &session.peer,
3517    );
3518
3519    Ok(())
3520}
3521
3522fn channel_buffer_updated<T: EnvelopedMessage>(
3523    sender_id: ConnectionId,
3524    collaborators: impl IntoIterator<Item = ConnectionId>,
3525    message: &T,
3526    peer: &Peer,
3527) {
3528    broadcast(Some(sender_id), collaborators, |peer_id| {
3529        peer.send(peer_id, message.clone())
3530    });
3531}
3532
3533fn send_notifications(
3534    connection_pool: &ConnectionPool,
3535    peer: &Peer,
3536    notifications: db::NotificationBatch,
3537) {
3538    for (user_id, notification) in notifications {
3539        for connection_id in connection_pool.user_connection_ids(user_id) {
3540            if let Err(error) = peer.send(
3541                connection_id,
3542                proto::AddNotification {
3543                    notification: Some(notification.clone()),
3544                },
3545            ) {
3546                tracing::error!(
3547                    "failed to send notification to {:?} {}",
3548                    connection_id,
3549                    error
3550                );
3551            }
3552        }
3553    }
3554}
3555
3556/// Send a message to the channel
3557async fn send_channel_message(
3558    request: proto::SendChannelMessage,
3559    response: Response<proto::SendChannelMessage>,
3560    session: Session,
3561) -> Result<()> {
3562    // Validate the message body.
3563    let body = request.body.trim().to_string();
3564    if body.len() > MAX_MESSAGE_LEN {
3565        return Err(anyhow!("message is too long"))?;
3566    }
3567    if body.is_empty() {
3568        return Err(anyhow!("message can't be blank"))?;
3569    }
3570
3571    // TODO: adjust mentions if body is trimmed
3572
3573    let timestamp = OffsetDateTime::now_utc();
3574    let nonce = request.nonce.context("nonce can't be blank")?;
3575
3576    let channel_id = ChannelId::from_proto(request.channel_id);
3577    let CreatedChannelMessage {
3578        message_id,
3579        participant_connection_ids,
3580        notifications,
3581    } = session
3582        .db()
3583        .await
3584        .create_channel_message(
3585            channel_id,
3586            session.user_id(),
3587            &body,
3588            &request.mentions,
3589            timestamp,
3590            nonce.clone().into(),
3591            request.reply_to_message_id.map(MessageId::from_proto),
3592        )
3593        .await?;
3594
3595    let message = proto::ChannelMessage {
3596        sender_id: session.user_id().to_proto(),
3597        id: message_id.to_proto(),
3598        body,
3599        mentions: request.mentions,
3600        timestamp: timestamp.unix_timestamp() as u64,
3601        nonce: Some(nonce),
3602        reply_to_message_id: request.reply_to_message_id,
3603        edited_at: None,
3604    };
3605    broadcast(
3606        Some(session.connection_id),
3607        participant_connection_ids.clone(),
3608        |connection| {
3609            session.peer.send(
3610                connection,
3611                proto::ChannelMessageSent {
3612                    channel_id: channel_id.to_proto(),
3613                    message: Some(message.clone()),
3614                },
3615            )
3616        },
3617    );
3618    response.send(proto::SendChannelMessageResponse {
3619        message: Some(message),
3620    })?;
3621
3622    let pool = &*session.connection_pool().await;
3623    let non_participants =
3624        pool.channel_connection_ids(channel_id)
3625            .filter_map(|(connection_id, _)| {
3626                if participant_connection_ids.contains(&connection_id) {
3627                    None
3628                } else {
3629                    Some(connection_id)
3630                }
3631            });
3632    broadcast(None, non_participants, |peer_id| {
3633        session.peer.send(
3634            peer_id,
3635            proto::UpdateChannels {
3636                latest_channel_message_ids: vec![proto::ChannelMessageId {
3637                    channel_id: channel_id.to_proto(),
3638                    message_id: message_id.to_proto(),
3639                }],
3640                ..Default::default()
3641            },
3642        )
3643    });
3644    send_notifications(pool, &session.peer, notifications);
3645
3646    Ok(())
3647}
3648
3649/// Delete a channel message
3650async fn remove_channel_message(
3651    request: proto::RemoveChannelMessage,
3652    response: Response<proto::RemoveChannelMessage>,
3653    session: Session,
3654) -> Result<()> {
3655    let channel_id = ChannelId::from_proto(request.channel_id);
3656    let message_id = MessageId::from_proto(request.message_id);
3657    let (connection_ids, existing_notification_ids) = session
3658        .db()
3659        .await
3660        .remove_channel_message(channel_id, message_id, session.user_id())
3661        .await?;
3662
3663    broadcast(
3664        Some(session.connection_id),
3665        connection_ids,
3666        move |connection| {
3667            session.peer.send(connection, request.clone())?;
3668
3669            for notification_id in &existing_notification_ids {
3670                session.peer.send(
3671                    connection,
3672                    proto::DeleteNotification {
3673                        notification_id: (*notification_id).to_proto(),
3674                    },
3675                )?;
3676            }
3677
3678            Ok(())
3679        },
3680    );
3681    response.send(proto::Ack {})?;
3682    Ok(())
3683}
3684
3685async fn update_channel_message(
3686    request: proto::UpdateChannelMessage,
3687    response: Response<proto::UpdateChannelMessage>,
3688    session: Session,
3689) -> Result<()> {
3690    let channel_id = ChannelId::from_proto(request.channel_id);
3691    let message_id = MessageId::from_proto(request.message_id);
3692    let updated_at = OffsetDateTime::now_utc();
3693    let UpdatedChannelMessage {
3694        message_id,
3695        participant_connection_ids,
3696        notifications,
3697        reply_to_message_id,
3698        timestamp,
3699        deleted_mention_notification_ids,
3700        updated_mention_notifications,
3701    } = session
3702        .db()
3703        .await
3704        .update_channel_message(
3705            channel_id,
3706            message_id,
3707            session.user_id(),
3708            request.body.as_str(),
3709            &request.mentions,
3710            updated_at,
3711        )
3712        .await?;
3713
3714    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3715
3716    let message = proto::ChannelMessage {
3717        sender_id: session.user_id().to_proto(),
3718        id: message_id.to_proto(),
3719        body: request.body.clone(),
3720        mentions: request.mentions.clone(),
3721        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3722        nonce: Some(nonce),
3723        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3724        edited_at: Some(updated_at.unix_timestamp() as u64),
3725    };
3726
3727    response.send(proto::Ack {})?;
3728
3729    let pool = &*session.connection_pool().await;
3730    broadcast(
3731        Some(session.connection_id),
3732        participant_connection_ids,
3733        |connection| {
3734            session.peer.send(
3735                connection,
3736                proto::ChannelMessageUpdate {
3737                    channel_id: channel_id.to_proto(),
3738                    message: Some(message.clone()),
3739                },
3740            )?;
3741
3742            for notification_id in &deleted_mention_notification_ids {
3743                session.peer.send(
3744                    connection,
3745                    proto::DeleteNotification {
3746                        notification_id: (*notification_id).to_proto(),
3747                    },
3748                )?;
3749            }
3750
3751            for notification in &updated_mention_notifications {
3752                session.peer.send(
3753                    connection,
3754                    proto::UpdateNotification {
3755                        notification: Some(notification.clone()),
3756                    },
3757                )?;
3758            }
3759
3760            Ok(())
3761        },
3762    );
3763
3764    send_notifications(pool, &session.peer, notifications);
3765
3766    Ok(())
3767}
3768
3769/// Mark a channel message as read
3770async fn acknowledge_channel_message(
3771    request: proto::AckChannelMessage,
3772    session: Session,
3773) -> Result<()> {
3774    let channel_id = ChannelId::from_proto(request.channel_id);
3775    let message_id = MessageId::from_proto(request.message_id);
3776    let notifications = session
3777        .db()
3778        .await
3779        .observe_channel_message(channel_id, session.user_id(), message_id)
3780        .await?;
3781    send_notifications(
3782        &*session.connection_pool().await,
3783        &session.peer,
3784        notifications,
3785    );
3786    Ok(())
3787}
3788
3789/// Mark a buffer version as synced
3790async fn acknowledge_buffer_version(
3791    request: proto::AckBufferOperation,
3792    session: Session,
3793) -> Result<()> {
3794    let buffer_id = BufferId::from_proto(request.buffer_id);
3795    session
3796        .db()
3797        .await
3798        .observe_buffer_version(
3799            buffer_id,
3800            session.user_id(),
3801            request.epoch as i32,
3802            &request.version,
3803        )
3804        .await?;
3805    Ok(())
3806}
3807
3808/// Get a Supermaven API key for the user
3809async fn get_supermaven_api_key(
3810    _request: proto::GetSupermavenApiKey,
3811    response: Response<proto::GetSupermavenApiKey>,
3812    session: Session,
3813) -> Result<()> {
3814    let user_id: String = session.user_id().to_string();
3815    if !session.is_staff() {
3816        return Err(anyhow!("supermaven not enabled for this account"))?;
3817    }
3818
3819    let email = session.email().context("user must have an email")?;
3820
3821    let supermaven_admin_api = session
3822        .supermaven_client
3823        .as_ref()
3824        .context("supermaven not configured")?;
3825
3826    let result = supermaven_admin_api
3827        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3828        .await?;
3829
3830    response.send(proto::GetSupermavenApiKeyResponse {
3831        api_key: result.api_key,
3832    })?;
3833
3834    Ok(())
3835}
3836
3837/// Start receiving chat updates for a channel
3838async fn join_channel_chat(
3839    request: proto::JoinChannelChat,
3840    response: Response<proto::JoinChannelChat>,
3841    session: Session,
3842) -> Result<()> {
3843    let channel_id = ChannelId::from_proto(request.channel_id);
3844
3845    let db = session.db().await;
3846    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3847        .await?;
3848    let messages = db
3849        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3850        .await?;
3851    response.send(proto::JoinChannelChatResponse {
3852        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3853        messages,
3854    })?;
3855    Ok(())
3856}
3857
3858/// Stop receiving chat updates for a channel
3859async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3860    let channel_id = ChannelId::from_proto(request.channel_id);
3861    session
3862        .db()
3863        .await
3864        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3865        .await?;
3866    Ok(())
3867}
3868
3869/// Retrieve the chat history for a channel
3870async fn get_channel_messages(
3871    request: proto::GetChannelMessages,
3872    response: Response<proto::GetChannelMessages>,
3873    session: Session,
3874) -> Result<()> {
3875    let channel_id = ChannelId::from_proto(request.channel_id);
3876    let messages = session
3877        .db()
3878        .await
3879        .get_channel_messages(
3880            channel_id,
3881            session.user_id(),
3882            MESSAGE_COUNT_PER_PAGE,
3883            Some(MessageId::from_proto(request.before_message_id)),
3884        )
3885        .await?;
3886    response.send(proto::GetChannelMessagesResponse {
3887        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888        messages,
3889    })?;
3890    Ok(())
3891}
3892
3893/// Retrieve specific chat messages
3894async fn get_channel_messages_by_id(
3895    request: proto::GetChannelMessagesById,
3896    response: Response<proto::GetChannelMessagesById>,
3897    session: Session,
3898) -> Result<()> {
3899    let message_ids = request
3900        .message_ids
3901        .iter()
3902        .map(|id| MessageId::from_proto(*id))
3903        .collect::<Vec<_>>();
3904    let messages = session
3905        .db()
3906        .await
3907        .get_channel_messages_by_id(session.user_id(), &message_ids)
3908        .await?;
3909    response.send(proto::GetChannelMessagesResponse {
3910        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3911        messages,
3912    })?;
3913    Ok(())
3914}
3915
3916/// Retrieve the current users notifications
3917async fn get_notifications(
3918    request: proto::GetNotifications,
3919    response: Response<proto::GetNotifications>,
3920    session: Session,
3921) -> Result<()> {
3922    let notifications = session
3923        .db()
3924        .await
3925        .get_notifications(
3926            session.user_id(),
3927            NOTIFICATION_COUNT_PER_PAGE,
3928            request.before_id.map(db::NotificationId::from_proto),
3929        )
3930        .await?;
3931    response.send(proto::GetNotificationsResponse {
3932        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3933        notifications,
3934    })?;
3935    Ok(())
3936}
3937
3938/// Mark notifications as read
3939async fn mark_notification_as_read(
3940    request: proto::MarkNotificationRead,
3941    response: Response<proto::MarkNotificationRead>,
3942    session: Session,
3943) -> Result<()> {
3944    let database = &session.db().await;
3945    let notifications = database
3946        .mark_notification_as_read_by_id(
3947            session.user_id(),
3948            NotificationId::from_proto(request.notification_id),
3949        )
3950        .await?;
3951    send_notifications(
3952        &*session.connection_pool().await,
3953        &session.peer,
3954        notifications,
3955    );
3956    response.send(proto::Ack {})?;
3957    Ok(())
3958}
3959
3960/// Get the current users information
3961async fn get_private_user_info(
3962    _request: proto::GetPrivateUserInfo,
3963    response: Response<proto::GetPrivateUserInfo>,
3964    session: Session,
3965) -> Result<()> {
3966    let db = session.db().await;
3967
3968    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3969    let user = db
3970        .get_user_by_id(session.user_id())
3971        .await?
3972        .context("user not found")?;
3973    let flags = db.get_user_flags(session.user_id()).await?;
3974
3975    response.send(proto::GetPrivateUserInfoResponse {
3976        metrics_id,
3977        staff: user.admin,
3978        flags,
3979        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3980    })?;
3981    Ok(())
3982}
3983
3984/// Accept the terms of service (tos) on behalf of the current user
3985async fn accept_terms_of_service(
3986    _request: proto::AcceptTermsOfService,
3987    response: Response<proto::AcceptTermsOfService>,
3988    session: Session,
3989) -> Result<()> {
3990    let db = session.db().await;
3991
3992    let accepted_tos_at = Utc::now();
3993    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3994        .await?;
3995
3996    response.send(proto::AcceptTermsOfServiceResponse {
3997        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3998    })?;
3999    Ok(())
4000}
4001
/// The minimum age an account must have in order to use the LLM service.
///
/// NOTE(review): `get_llm_api_token` below does not consult this constant;
/// confirm where (or whether) the age requirement is enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4004
/// Issues an LLM API token for the current user.
///
/// The user must have accepted the terms of service. If the user has no
/// billing customer yet, one is found or created through Stripe (by email).
/// If the user has no active billing subscription, a Zed Free subscription is
/// created in Stripe and recorded in the database. The token is then built by
/// `LlmTokenClaims::create` from:
/// - the user and staff status;
/// - billing customer, preferences, and subscription;
/// - feature flags;
/// - the session's system id and the app config.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the user is missing or has not accepted the terms of service;
/// - the Stripe client or Stripe billing object is not configured;
/// - the stored Stripe customer id cannot be parsed;
/// - any database or Stripe call fails;
/// - token creation fails.
///
/// NOTE(review): the `db` handle is held for the whole request, including the
/// Stripe network calls. Two concurrent requests for a user without an active
/// subscription may both reach `subscribe_to_zed_free`; confirm this is
/// prevented elsewhere.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Tokens are only issued to users who have accepted the terms of service.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, otherwise find or create one in Stripe
    // by the user's email address.
    let billing_customer =
        if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
            billing_customer
        } else {
            let customer_id = stripe_billing
                .find_or_create_customer_by_email(user.email_address.as_deref())
                .await?;

            find_or_create_billing_customer(
                &session.app_state,
                &stripe_client,
                stripe::Expandable::Id(customer_id),
            )
            .await?
            .context("billing customer not found")?
        };

    // Reuse the active subscription, otherwise subscribe the customer to Zed
    // Free in Stripe and record that subscription locally.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id = billing_customer
                .stripe_customer_id
                .parse::<stripe::CustomerId>()
                .context("failed to parse Stripe customer ID from database")?;

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4093
4094fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4095    let message = match message {
4096        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4097        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4098        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4099        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4100        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4101            code: frame.code.into(),
4102            reason: frame.reason.as_str().to_owned().into(),
4103        })),
4104        // We should never receive a frame while reading the message, according
4105        // to the `tungstenite` maintainers:
4106        //
4107        // > It cannot occur when you read messages from the WebSocket, but it
4108        // > can be used when you want to send the raw frames (e.g. you want to
4109        // > send the frames to the WebSocket without composing the full message first).
4110        // >
4111        // > — https://github.com/snapview/tungstenite-rs/issues/268
4112        TungsteniteMessage::Frame(_) => {
4113            bail!("received an unexpected frame while reading the message")
4114        }
4115    };
4116
4117    Ok(message)
4118}
4119
4120fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4121    match message {
4122        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4123        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4124        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4125        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4126        AxumMessage::Close(frame) => {
4127            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4128                code: frame.code.into(),
4129                reason: frame.reason.as_ref().into(),
4130            }))
4131        }
4132    }
4133}
4134
4135fn notify_membership_updated(
4136    connection_pool: &mut ConnectionPool,
4137    result: MembershipUpdated,
4138    user_id: UserId,
4139    peer: &Peer,
4140) {
4141    for membership in &result.new_channels.channel_memberships {
4142        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4143    }
4144    for channel_id in &result.removed_channels {
4145        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4146    }
4147
4148    let user_channels_update = proto::UpdateUserChannels {
4149        channel_memberships: result
4150            .new_channels
4151            .channel_memberships
4152            .iter()
4153            .map(|cm| proto::ChannelMembership {
4154                channel_id: cm.channel_id.to_proto(),
4155                role: cm.role.into(),
4156            })
4157            .collect(),
4158        ..Default::default()
4159    };
4160
4161    let mut update = build_channels_update(result.new_channels);
4162    update.delete_channels = result
4163        .removed_channels
4164        .into_iter()
4165        .map(|id| id.to_proto())
4166        .collect();
4167    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4168
4169    for connection_id in connection_pool.user_connection_ids(user_id) {
4170        peer.send(connection_id, user_channels_update.clone())
4171            .trace_err();
4172        peer.send(connection_id, update.clone()).trace_err();
4173    }
4174}
4175
4176fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4177    proto::UpdateUserChannels {
4178        channel_memberships: channels
4179            .channel_memberships
4180            .iter()
4181            .map(|m| proto::ChannelMembership {
4182                channel_id: m.channel_id.to_proto(),
4183                role: m.role.into(),
4184            })
4185            .collect(),
4186        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4187        observed_channel_message_id: channels.observed_channel_messages.clone(),
4188    }
4189}
4190
4191fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4192    let mut update = proto::UpdateChannels::default();
4193
4194    for channel in channels.channels {
4195        update.channels.push(channel.to_proto());
4196    }
4197
4198    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4199    update.latest_channel_message_ids = channels.latest_channel_messages;
4200
4201    for (channel_id, participants) in channels.channel_participants {
4202        update
4203            .channel_participants
4204            .push(proto::ChannelParticipants {
4205                channel_id: channel_id.to_proto(),
4206                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4207            });
4208    }
4209
4210    for channel in channels.invited_channels {
4211        update.channel_invitations.push(channel.to_proto());
4212    }
4213
4214    update
4215}
4216
4217fn build_initial_contacts_update(
4218    contacts: Vec<db::Contact>,
4219    pool: &ConnectionPool,
4220) -> proto::UpdateContacts {
4221    let mut update = proto::UpdateContacts::default();
4222
4223    for contact in contacts {
4224        match contact {
4225            db::Contact::Accepted { user_id, busy } => {
4226                update.contacts.push(contact_for_user(user_id, busy, pool));
4227            }
4228            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4229            db::Contact::Incoming { user_id } => {
4230                update
4231                    .incoming_requests
4232                    .push(proto::IncomingContactRequest {
4233                        requester_id: user_id.to_proto(),
4234                    })
4235            }
4236        }
4237    }
4238
4239    update
4240}
4241
4242fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4243    proto::Contact {
4244        user_id: user_id.to_proto(),
4245        online: pool.is_user_online(user_id),
4246        busy,
4247    }
4248}
4249
4250fn room_updated(room: &proto::Room, peer: &Peer) {
4251    broadcast(
4252        None,
4253        room.participants
4254            .iter()
4255            .filter_map(|participant| Some(participant.peer_id?.into())),
4256        |peer_id| {
4257            peer.send(
4258                peer_id,
4259                proto::RoomUpdated {
4260                    room: Some(room.clone()),
4261                },
4262            )
4263        },
4264    );
4265}
4266
4267fn channel_updated(
4268    channel: &db::channel::Model,
4269    room: &proto::Room,
4270    peer: &Peer,
4271    pool: &ConnectionPool,
4272) {
4273    let participants = room
4274        .participants
4275        .iter()
4276        .map(|p| p.user_id)
4277        .collect::<Vec<_>>();
4278
4279    broadcast(
4280        None,
4281        pool.channel_connection_ids(channel.root_id())
4282            .filter_map(|(channel_id, role)| {
4283                role.can_see_channel(channel.visibility)
4284                    .then_some(channel_id)
4285            }),
4286        |peer_id| {
4287            peer.send(
4288                peer_id,
4289                proto::UpdateChannels {
4290                    channel_participants: vec![proto::ChannelParticipants {
4291                        channel_id: channel.id.to_proto(),
4292                        participant_user_ids: participants.clone(),
4293                    }],
4294                    ..Default::default()
4295                },
4296            )
4297        },
4298    );
4299}
4300
4301async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4302    let db = session.db().await;
4303
4304    let contacts = db.get_contacts(user_id).await?;
4305    let busy = db.is_user_busy(user_id).await?;
4306
4307    let pool = session.connection_pool().await;
4308    let updated_contact = contact_for_user(user_id, busy, &pool);
4309    for contact in contacts {
4310        if let db::Contact::Accepted {
4311            user_id: contact_user_id,
4312            ..
4313        } = contact
4314        {
4315            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4316                session
4317                    .peer
4318                    .send(
4319                        contact_conn_id,
4320                        proto::UpdateContacts {
4321                            contacts: vec![updated_contact.clone()],
4322                            remove_contacts: Default::default(),
4323                            incoming_requests: Default::default(),
4324                            remove_incoming_requests: Default::default(),
4325                            outgoing_requests: Default::default(),
4326                            remove_outgoing_requests: Default::default(),
4327                        },
4328                    )
4329                    .trace_err();
4330            }
4331        }
4332    }
4333    Ok(())
4334}
4335
/// Removes `connection_id` from the room it is in and propagates the effects.
///
/// When the database reports that the connection left a room, this:
/// - calls `project_left` for every project in `left_room.left_projects`;
/// - broadcasts the updated room via `room_updated`;
/// - for rooms attached to a channel, broadcasts the channel's participant list;
/// - sends `CallCanceled` for this room to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - refreshes contact state for this user and for each of those users;
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room, and deletes the room when `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately when `leave_room` yields `None`.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned only in the `Some` branch below; the `else`
    // branch returns early, so they are always initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // the `if let` scrutinee, so it stays alive through the body of the `Some`
    // branch. If it is a lock guard, the database stays locked while
    // `project_left` and `room_updated` run.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below has an
        // empty `livekit_room`. NOTE(review): presumably intentional, to keep the
        // LiveKit room name out of the broadcast; confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is dropped before the awaits below;
    // `update_user_contacts` acquires the connection pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // This inner `connection_id` shadows the function parameter.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` returns on the first failed contact update, which also
    // skips the LiveKit cleanup below; consider whether cleanup should still run.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4409
4410async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4411    let left_channel_buffers = session
4412        .db()
4413        .await
4414        .leave_channel_buffers(session.connection_id)
4415        .await?;
4416
4417    for left_buffer in left_channel_buffers {
4418        channel_buffer_updated(
4419            session.connection_id,
4420            left_buffer.connections,
4421            &proto::UpdateChannelBufferCollaborators {
4422                channel_id: left_buffer.channel_id.to_proto(),
4423                collaborators: left_buffer.collaborators,
4424            },
4425            &session.peer,
4426        );
4427    }
4428
4429    Ok(())
4430}
4431
4432fn project_left(project: &db::LeftProject, session: &Session) {
4433    for connection_id in &project.connection_ids {
4434        if project.should_unshare {
4435            session
4436                .peer
4437                .send(
4438                    *connection_id,
4439                    proto::UnshareProject {
4440                        project_id: project.id.to_proto(),
4441                    },
4442                )
4443                .trace_err();
4444        } else {
4445            session
4446                .peer
4447                .send(
4448                    *connection_id,
4449                    proto::RemoveProjectCollaborator {
4450                        project_id: project.id.to_proto(),
4451                        peer_id: Some(session.connection_id.into()),
4452                    },
4453                )
4454                .trace_err();
4455        }
4456    }
4457}
4458
/// Extension for fallible values whose failures should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success value type.
    type Ok;

    /// Converts `self` into an `Option`, yielding `Some` for a success value
    /// and `None` for a failure.
    ///
    /// The `Result` implementation in this module logs the error before
    /// returning `None`.
    fn trace_err(self) -> Option<<Self as ResultExt>::Ok>;
}
4464
4465impl<T, E> ResultExt for Result<T, E>
4466where
4467    E: std::fmt::Debug,
4468{
4469    type Ok = T;
4470
4471    #[track_caller]
4472    fn trace_err(self) -> Option<T> {
4473        match self {
4474            Ok(value) => Some(value),
4475            Err(error) => {
4476                tracing::error!("{:?}", error);
4477                None
4478            }
4479        }
4480    }
4481}