rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::stripe_client::StripeCustomerId;
   9use crate::{
  10    AppState, Error, Result, auth,
  11    db::{
  12        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  13        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  14        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  15        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  16    },
  17    executor::Executor,
  18};
  19use anyhow::{Context as _, anyhow, bail};
  20use async_tungstenite::tungstenite::{
  21    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  22};
  23use axum::{
  24    Extension, Router, TypedHeader,
  25    body::Body,
  26    extract::{
  27        ConnectInfo, WebSocketUpgrade,
  28        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  29    },
  30    headers::{Header, HeaderName},
  31    http::StatusCode,
  32    middleware,
  33    response::IntoResponse,
  34    routing::get,
  35};
  36use chrono::Utc;
  37use collections::{HashMap, HashSet};
  38pub use connection_pool::{ConnectionPool, ZedVersion};
  39use core::fmt::{self, Debug, Formatter};
  40use reqwest_client::ReqwestClient;
  41use rpc::proto::split_repository_update;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  46    stream::FuturesUnordered,
  47};
  48use prometheus::{IntGauge, register_int_gauge};
  49use rpc::{
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        Arc, OnceLock,
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{Semaphore, watch};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    Instrument,
  77    field::{self},
  78    info_span, instrument,
  79};
  80
/// Reconnection grace period.
///
/// NOTE(review): presumably how long a disconnected client may take to
/// reconnect before its state is released — confirm against the call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources.
///
/// `Server::start` waits this long before looking for stale rooms and channel
/// buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for message history.
///
/// NOTE(review): usage not visible here — presumably the number of channel
/// messages returned per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length.
///
/// NOTE(review): units (bytes vs. chars) are not established here — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (usage not visible here).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler.
///
/// It takes any boxed envelope plus the receiving [`Session`] and returns a
/// boxed future that processes the message. Handlers are stored per message
/// type in `Server::handlers`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111}
 112
 113impl Principal {
 114    fn user(&self) -> &User {
 115        match self {
 116            Principal::User(user) => user,
 117            Principal::Impersonated { user, .. } => user,
 118        }
 119    }
 120
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132        }
 133    }
 134}
 135
/// Per-connection state passed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as; possibly an admin impersonating a user.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous mutex; lock it via
    /// `Session::connection_pool`, which wraps the guard to make it `!Send`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client; `None` when none is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in the visible code, hence the lint allowance.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn is_staff(&self) -> bool {
 173        match &self.principal {
 174            Principal::User(user) => user.admin,
 175            Principal::Impersonated { .. } => true,
 176        }
 177    }
 178
 179    fn user_id(&self) -> UserId {
 180        match &self.principal {
 181            Principal::User(user) => user.id,
 182            Principal::Impersonated { user, .. } => user.id,
 183        }
 184    }
 185
 186    pub fn email(&self) -> Option<String> {
 187        match &self.principal {
 188            Principal::User(user) => user.email_address.clone(),
 189            Principal::Impersonated { user, .. } => user.email_address.clone(),
 190        }
 191    }
 192}
 193
 194impl Debug for Session {
 195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 196        let mut result = f.debug_struct("Session");
 197        match &self.principal {
 198            Principal::User(user) => {
 199                result.field("user", &user.github_login);
 200            }
 201            Principal::Impersonated { user, admin } => {
 202                result.field("user", &user.github_login);
 203                result.field("impersonator", &admin.github_login);
 204            }
 205        }
 206        result.field("connection_id", &self.connection_id).finish()
 207    }
 208}
 209
 210struct DbHandle(Arc<Database>);
 211
 212impl Deref for DbHandle {
 213    type Target = Database;
 214
 215    fn deref(&self) -> &Self::Target {
 216        self.0.as_ref()
 217    }
 218}
 219
/// The RPC server.
///
/// It owns the peer, the pool of live connections, and the table of
/// per-message-type handlers.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`, and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 228
/// A lock guard over the [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send` (and `!Sync`).
/// As a result it cannot be kept alive across an `.await` inside a future that
/// is required to be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 233
/// A serializable view of the server's peer and connection pool.
///
/// The pool is serialized through its guard using [`serialize_deref`]. Because
/// the guard is a field, the pool lock is held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 240
 241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 242where
 243    S: Serializer,
 244    T: Deref<Target = U>,
 245    U: Serialize,
 246{
 247    Serialize::serialize(value.deref(), serializer)
 248}
 249
 250impl Server {
 251    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 252        let mut server = Self {
 253            id: parking_lot::Mutex::new(id),
 254            peer: Peer::new(id.0 as u32),
 255            app_state: app_state.clone(),
 256            connection_pool: Default::default(),
 257            handlers: Default::default(),
 258            teardown: watch::channel(false).0,
 259        };
 260
 261        server
 262            .add_request_handler(ping)
 263            .add_request_handler(create_room)
 264            .add_request_handler(join_room)
 265            .add_request_handler(rejoin_room)
 266            .add_request_handler(leave_room)
 267            .add_request_handler(set_room_participant_role)
 268            .add_request_handler(call)
 269            .add_request_handler(cancel_call)
 270            .add_message_handler(decline_call)
 271            .add_request_handler(update_participant_location)
 272            .add_request_handler(share_project)
 273            .add_message_handler(unshare_project)
 274            .add_request_handler(join_project)
 275            .add_message_handler(leave_project)
 276            .add_request_handler(update_project)
 277            .add_request_handler(update_worktree)
 278            .add_request_handler(update_repository)
 279            .add_request_handler(remove_repository)
 280            .add_message_handler(start_language_server)
 281            .add_message_handler(update_language_server)
 282            .add_message_handler(update_diagnostic_summary)
 283            .add_message_handler(update_worktree_settings)
 284            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 285            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 286            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 287            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 288            .add_request_handler(forward_find_search_candidates_request)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 291            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 293            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 294            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 295            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 296            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 297            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 298            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 300            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 301            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 303            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 304            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 305            .add_request_handler(
 306                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 307            )
 308            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 309            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 310            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 311            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 312            .add_request_handler(
 313                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 314            )
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 349            .add_message_handler(create_buffer_for_peer)
 350            .add_request_handler(update_buffer)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 357            .add_request_handler(get_users)
 358            .add_request_handler(fuzzy_search_users)
 359            .add_request_handler(request_contact)
 360            .add_request_handler(remove_contact)
 361            .add_request_handler(respond_to_contact_request)
 362            .add_message_handler(subscribe_to_channels)
 363            .add_request_handler(create_channel)
 364            .add_request_handler(delete_channel)
 365            .add_request_handler(invite_channel_member)
 366            .add_request_handler(remove_channel_member)
 367            .add_request_handler(set_channel_member_role)
 368            .add_request_handler(set_channel_visibility)
 369            .add_request_handler(rename_channel)
 370            .add_request_handler(join_channel_buffer)
 371            .add_request_handler(leave_channel_buffer)
 372            .add_message_handler(update_channel_buffer)
 373            .add_request_handler(rejoin_channel_buffers)
 374            .add_request_handler(get_channel_members)
 375            .add_request_handler(respond_to_channel_invite)
 376            .add_request_handler(join_channel)
 377            .add_request_handler(join_channel_chat)
 378            .add_message_handler(leave_channel_chat)
 379            .add_request_handler(send_channel_message)
 380            .add_request_handler(remove_channel_message)
 381            .add_request_handler(update_channel_message)
 382            .add_request_handler(get_channel_messages)
 383            .add_request_handler(get_channel_messages_by_id)
 384            .add_request_handler(get_notifications)
 385            .add_request_handler(mark_notification_as_read)
 386            .add_request_handler(move_channel)
 387            .add_request_handler(follow)
 388            .add_message_handler(unfollow)
 389            .add_message_handler(update_followers)
 390            .add_request_handler(get_private_user_info)
 391            .add_request_handler(get_llm_api_token)
 392            .add_request_handler(accept_terms_of_service)
 393            .add_message_handler(acknowledge_channel_message)
 394            .add_message_handler(acknowledge_buffer_version)
 395            .add_request_handler(get_supermaven_api_key)
 396            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 398            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 402            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 405            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 406            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 407            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 408            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 409            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 411            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 412            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 415            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 416            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 417            .add_message_handler(update_context);
 418
 419        Arc::new(server)
 420    }
 421
    /// Schedules cleanup of resources left behind by previous server instances,
    /// then returns immediately.
    ///
    /// A detached background task waits for [`CLEANUP_TIMEOUT`]. It then asks the
    /// database for the room ids and channel-buffer ids it reports as stale for
    /// this `server_id` in the configured `zed_environment`.
    /// NOTE(review): presumably these are resources still attributed to servers
    /// that have gone away — confirm in `stale_server_resource_ids`.
    ///
    /// For each stale channel buffer, the task clears stale collaborators and
    /// broadcasts the refreshed collaborator list to the remaining connections.
    ///
    /// For each stale room, the task:
    /// - clears stale participants and notifies the room (and its channel, if any);
    /// - sends `CallCanceled` to users in `canceled_calls_to_user_ids`;
    /// - pushes refreshed contact entries for affected users to their accepted
    ///   contacts;
    /// - deletes the LiveKit room once it has no participants left (when a
    ///   LiveKit client is configured).
    ///
    /// Finally, the task deletes stale server records.
    ///
    /// Every fallible step inside the task goes through `trace_err()`, which
    /// yields an `Option` (presumably logging the error). Failures are therefore
    /// skipped rather than aborting the cleanup, and this function always
    /// returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone everything the detached task needs so it doesn't borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining connection on this buffer about the
                            // refreshed collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply when clearing the room fails: nothing
                        // to notify, cancel, or delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move these out of `refreshed_room`; they are used after
                            // this `if let` ends.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the pool lock so it is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only lock the pool after both database calls have completed,
                            // and only when both succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 567
 568    pub fn teardown(&self) {
 569        self.peer.teardown();
 570        self.connection_pool.lock().reset();
 571        let _ = self.teardown.send(true);
 572    }
 573
 574    #[cfg(test)]
 575    pub fn reset(&self, id: ServerId) {
 576        self.teardown();
 577        *self.id.lock() = id;
 578        self.peer.reset(id.0 as u32);
 579        let _ = self.teardown.send(false);
 580    }
 581
 582    #[cfg(test)]
 583    pub fn id(&self) -> ServerId {
 584        *self.id.lock()
 585    }
 586
 587    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 588    where
 589        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 590        Fut: 'static + Send + Future<Output = Result<()>>,
 591        M: EnvelopedMessage,
 592    {
 593        let prev_handler = self.handlers.insert(
 594            TypeId::of::<M>(),
 595            Box::new(move |envelope, session| {
 596                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 597                let received_at = envelope.received_at;
 598                tracing::info!("message received");
 599                let start_time = Instant::now();
 600                let future = (handler)(*envelope, session);
 601                async move {
 602                    let result = future.await;
 603                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 604                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 605                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 606                    let payload_type = M::NAME;
 607
 608                    match result {
 609                        Err(error) => {
 610                            tracing::error!(
 611                                ?error,
 612                                total_duration_ms,
 613                                processing_duration_ms,
 614                                queue_duration_ms,
 615                                payload_type,
 616                                "error handling message"
 617                            )
 618                        }
 619                        Ok(()) => tracing::info!(
 620                            total_duration_ms,
 621                            processing_duration_ms,
 622                            queue_duration_ms,
 623                            "finished handling message"
 624                        ),
 625                    }
 626                }
 627                .boxed()
 628            }),
 629        );
 630        if prev_handler.is_some() {
 631            panic!("registered a handler for the same message twice");
 632        }
 633        self
 634    }
 635
 636    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 637    where
 638        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 639        Fut: 'static + Send + Future<Output = Result<()>>,
 640        M: EnvelopedMessage,
 641    {
 642        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 643        self
 644    }
 645
 646    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 647    where
 648        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 649        Fut: Send + Future<Output = Result<()>>,
 650        M: RequestMessage,
 651    {
 652        let handler = Arc::new(handler);
 653        self.add_handler(move |envelope, session| {
 654            let receipt = envelope.receipt();
 655            let handler = handler.clone();
 656            async move {
 657                let peer = session.peer.clone();
 658                let responded = Arc::new(AtomicBool::default());
 659                let response = Response {
 660                    peer: peer.clone(),
 661                    responded: responded.clone(),
 662                    receipt,
 663                };
 664                match (handler)(envelope.payload, response, session).await {
 665                    Ok(()) => {
 666                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 667                            Ok(())
 668                        } else {
 669                            Err(anyhow!("handler did not send a response"))?
 670                        }
 671                    }
 672                    Err(error) => {
 673                        let proto_err = match &error {
 674                            Error::Internal(err) => err.to_proto(),
 675                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 676                        };
 677                        peer.respond_with_error(receipt, proto_err)?;
 678                        Err(error)
 679                    }
 680                }
 681            }
 682        })
 683    }
 684
 685    pub fn handle_connection(
 686        self: &Arc<Self>,
 687        connection: Connection,
 688        address: String,
 689        principal: Principal,
 690        zed_version: ZedVersion,
 691        geoip_country_code: Option<String>,
 692        system_id: Option<String>,
 693        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 694        executor: Executor,
 695    ) -> impl Future<Output = ()> + use<> {
 696        let this = self.clone();
 697        let span = info_span!("handle connection", %address,
 698            connection_id=field::Empty,
 699            user_id=field::Empty,
 700            login=field::Empty,
 701            impersonator=field::Empty,
 702            geoip_country_code=field::Empty
 703        );
 704        principal.update_span(&span);
 705        if let Some(country_code) = geoip_country_code.as_ref() {
 706            span.record("geoip_country_code", country_code);
 707        }
 708
 709        let mut teardown = self.teardown.subscribe();
 710        async move {
 711            if *teardown.borrow() {
 712                tracing::error!("server is tearing down");
 713                return
 714            }
 715            let (connection_id, handle_io, mut incoming_rx) = this
 716                .peer
 717                .add_connection(connection, {
 718                    let executor = executor.clone();
 719                    move |duration| executor.sleep(duration)
 720                });
 721            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 722
 723            tracing::info!("connection opened");
 724
 725            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 726            let http_client = match ReqwestClient::user_agent(&user_agent) {
 727                Ok(http_client) => Arc::new(http_client),
 728                Err(error) => {
 729                    tracing::error!(?error, "failed to create HTTP client");
 730                    return;
 731                }
 732            };
 733
 734            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 735                    supermaven_admin_api_key.to_string(),
 736                    http_client.clone(),
 737                )));
 738
 739            let session = Session {
 740                principal: principal.clone(),
 741                connection_id,
 742                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 743                peer: this.peer.clone(),
 744                connection_pool: this.connection_pool.clone(),
 745                app_state: this.app_state.clone(),
 746                geoip_country_code,
 747                system_id,
 748                _executor: executor.clone(),
 749                supermaven_client,
 750            };
 751
 752            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
 753                tracing::error!(?error, "failed to send initial client update");
 754                return;
 755            }
 756
 757            let handle_io = handle_io.fuse();
 758            futures::pin_mut!(handle_io);
 759
 760            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 761            // This prevents deadlocks when e.g., client A performs a request to client B and
 762            // client B performs a request to client A. If both clients stop processing further
 763            // messages until their respective request completes, they won't have a chance to
 764            // respond to the other client's request and cause a deadlock.
 765            //
 766            // This arrangement ensures we will attempt to process earlier messages first, but fall
 767            // back to processing messages arrived later in the spirit of making progress.
 768            let mut foreground_message_handlers = FuturesUnordered::new();
 769            let concurrent_handlers = Arc::new(Semaphore::new(256));
 770            loop {
 771                let next_message = async {
 772                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 773                    let message = incoming_rx.next().await;
 774                    (permit, message)
 775                }.fuse();
 776                futures::pin_mut!(next_message);
 777                futures::select_biased! {
 778                    _ = teardown.changed().fuse() => return,
 779                    result = handle_io => {
 780                        if let Err(error) = result {
 781                            tracing::error!(?error, "error handling I/O");
 782                        }
 783                        break;
 784                    }
 785                    _ = foreground_message_handlers.next() => {}
 786                    next_message = next_message => {
 787                        let (permit, message) = next_message;
 788                        if let Some(message) = message {
 789                            let type_name = message.payload_type_name();
 790                            // note: we copy all the fields from the parent span so we can query them in the logs.
 791                            // (https://github.com/tokio-rs/tracing/issues/2670).
 792                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 793                                user_id=field::Empty,
 794                                login=field::Empty,
 795                                impersonator=field::Empty,
 796                            );
 797                            principal.update_span(&span);
 798                            let span_enter = span.enter();
 799                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 800                                let is_background = message.is_background();
 801                                let handle_message = (handler)(message, session.clone());
 802                                drop(span_enter);
 803
 804                                let handle_message = async move {
 805                                    handle_message.await;
 806                                    drop(permit);
 807                                }.instrument(span);
 808                                if is_background {
 809                                    executor.spawn_detached(handle_message);
 810                                } else {
 811                                    foreground_message_handlers.push(handle_message);
 812                                }
 813                            } else {
 814                                tracing::error!("no message handler");
 815                            }
 816                        } else {
 817                            tracing::info!("connection closed");
 818                            break;
 819                        }
 820                    }
 821                }
 822            }
 823
 824            drop(foreground_message_handlers);
 825            tracing::info!("signing out");
 826            if let Err(error) = connection_lost(session, teardown, executor).await {
 827                tracing::error!(?error, "error signing out");
 828            }
 829
 830        }.instrument(span)
 831    }
 832
 833    async fn send_initial_client_update(
 834        &self,
 835        connection_id: ConnectionId,
 836        zed_version: ZedVersion,
 837        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 838        session: &Session,
 839    ) -> Result<()> {
 840        self.peer.send(
 841            connection_id,
 842            proto::Hello {
 843                peer_id: Some(connection_id.into()),
 844            },
 845        )?;
 846        tracing::info!("sent hello message");
 847        if let Some(send_connection_id) = send_connection_id.take() {
 848            let _ = send_connection_id.send(connection_id);
 849        }
 850
 851        match &session.principal {
 852            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 853                if !user.connected_once {
 854                    self.peer.send(connection_id, proto::ShowContacts {})?;
 855                    self.app_state
 856                        .db
 857                        .set_user_connected_once(user.id, true)
 858                        .await?;
 859                }
 860
 861                update_user_plan(session).await?;
 862
 863                let contacts = self.app_state.db.get_contacts(user.id).await?;
 864
 865                {
 866                    let mut pool = self.connection_pool.lock();
 867                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 868                    self.peer.send(
 869                        connection_id,
 870                        build_initial_contacts_update(contacts, &pool),
 871                    )?;
 872                }
 873
 874                if should_auto_subscribe_to_channels(zed_version) {
 875                    subscribe_user_to_channels(user.id, session).await?;
 876                }
 877
 878                if let Some(incoming_call) =
 879                    self.app_state.db.incoming_call_for_user(user.id).await?
 880                {
 881                    self.peer.send(connection_id, incoming_call)?;
 882                }
 883
 884                update_user_contacts(user.id, session).await?;
 885            }
 886        }
 887
 888        Ok(())
 889    }
 890
 891    pub async fn invite_code_redeemed(
 892        self: &Arc<Self>,
 893        inviter_id: UserId,
 894        invitee_id: UserId,
 895    ) -> Result<()> {
 896        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 897            if let Some(code) = &user.invite_code {
 898                let pool = self.connection_pool.lock();
 899                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 900                for connection_id in pool.user_connection_ids(inviter_id) {
 901                    self.peer.send(
 902                        connection_id,
 903                        proto::UpdateContacts {
 904                            contacts: vec![invitee_contact.clone()],
 905                            ..Default::default()
 906                        },
 907                    )?;
 908                    self.peer.send(
 909                        connection_id,
 910                        proto::UpdateInviteInfo {
 911                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 912                            count: user.invite_count as u32,
 913                        },
 914                    )?;
 915                }
 916            }
 917        }
 918        Ok(())
 919    }
 920
 921    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 923            if let Some(invite_code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                for connection_id in pool.user_connection_ids(user_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateInviteInfo {
 929                            url: format!(
 930                                "{}{}",
 931                                self.app_state.config.invite_link_prefix, invite_code
 932                            ),
 933                            count: user.invite_count as u32,
 934                        },
 935                    )?;
 936                }
 937            }
 938        }
 939        Ok(())
 940    }
 941
 942    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 943        let user = self
 944            .app_state
 945            .db
 946            .get_user_by_id(user_id)
 947            .await?
 948            .context("user not found")?;
 949
 950        let update_user_plan = make_update_user_plan_message(
 951            &user,
 952            user.admin,
 953            &self.app_state.db,
 954            self.app_state.llm_db.clone(),
 955        )
 956        .await?;
 957
 958        let pool = self.connection_pool.lock();
 959        for connection_id in pool.user_connection_ids(user_id) {
 960            self.peer
 961                .send(connection_id, update_user_plan.clone())
 962                .trace_err();
 963        }
 964
 965        Ok(())
 966    }
 967
 968    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 969        let pool = self.connection_pool.lock();
 970        for connection_id in pool.user_connection_ids(user_id) {
 971            self.peer
 972                .send(connection_id, proto::RefreshLlmToken {})
 973                .trace_err();
 974        }
 975    }
 976
 977    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 978        ServerSnapshot {
 979            connection_pool: ConnectionPoolGuard {
 980                guard: self.connection_pool.lock(),
 981                _not_send: PhantomData,
 982            },
 983            peer: &self.peer,
 984        }
 985    }
 986}
 987
/// Lets a `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is expected.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        // Deref-coerces through the held lock guard to the pool it protects.
        &self.guard
    }
}
 995
/// Lets a `ConnectionPoolGuard` be used wherever a `&mut ConnectionPool` is expected.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Mutable deref-coercion through the held lock guard to the pool.
        &mut self.guard
    }
}
1001
1002impl Drop for ConnectionPoolGuard<'_> {
1003    fn drop(&mut self) {
1004        #[cfg(test)]
1005        self.check_invariants();
1006    }
1007}
1008
1009fn broadcast<F>(
1010    sender_id: Option<ConnectionId>,
1011    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1012    mut f: F,
1013) where
1014    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1015{
1016    for receiver_id in receiver_ids {
1017        if Some(receiver_id) != sender_id {
1018            if let Err(error) = f(receiver_id) {
1019                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1020            }
1021        }
1022    }
1023}
1024
1025pub struct ProtocolVersion(u32);
1026
1027impl Header for ProtocolVersion {
1028    fn name() -> &'static HeaderName {
1029        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1030        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1031    }
1032
1033    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1034    where
1035        Self: Sized,
1036        I: Iterator<Item = &'i axum::http::HeaderValue>,
1037    {
1038        let version = values
1039            .next()
1040            .ok_or_else(axum::headers::Error::invalid)?
1041            .to_str()
1042            .map_err(|_| axum::headers::Error::invalid())?
1043            .parse()
1044            .map_err(|_| axum::headers::Error::invalid())?;
1045        Ok(Self(version))
1046    }
1047
1048    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1049        values.extend([self.0.to_string().parse().unwrap()]);
1050    }
1051}
1052
1053pub struct AppVersionHeader(SemanticVersion);
1054impl Header for AppVersionHeader {
1055    fn name() -> &'static HeaderName {
1056        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1057        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1058    }
1059
1060    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1061    where
1062        Self: Sized,
1063        I: Iterator<Item = &'i axum::http::HeaderValue>,
1064    {
1065        let version = values
1066            .next()
1067            .ok_or_else(axum::headers::Error::invalid)?
1068            .to_str()
1069            .map_err(|_| axum::headers::Error::invalid())?
1070            .parse()
1071            .map_err(|_| axum::headers::Error::invalid())?;
1072        Ok(Self(version))
1073    }
1074
1075    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1076        values.extend([self.0.to_string().parse().unwrap()]);
1077    }
1078}
1079
1080pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1081    Router::new()
1082        .route("/rpc", get(handle_websocket_request))
1083        .layer(
1084            ServiceBuilder::new()
1085                .layer(Extension(server.app_state.clone()))
1086                .layer(middleware::from_fn(auth::validate_header)),
1087        )
1088        .route("/metrics", get(handle_metrics))
1089        .layer(Extension(server))
1090}
1091
1092pub async fn handle_websocket_request(
1093    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1094    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1095    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1096    Extension(server): Extension<Arc<Server>>,
1097    Extension(principal): Extension<Principal>,
1098    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1099    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1100    ws: WebSocketUpgrade,
1101) -> axum::response::Response {
1102    if protocol_version != rpc::PROTOCOL_VERSION {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "client must be upgraded".to_string(),
1106        )
1107            .into_response();
1108    }
1109
1110    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "no version header found".to_string(),
1114        )
1115            .into_response();
1116    };
1117
1118    if !version.can_collaborate() {
1119        return (
1120            StatusCode::UPGRADE_REQUIRED,
1121            "client must be upgraded".to_string(),
1122        )
1123            .into_response();
1124    }
1125
1126    let socket_address = socket_address.to_string();
1127    ws.on_upgrade(move |socket| {
1128        let socket = socket
1129            .map_ok(to_tungstenite_message)
1130            .err_into()
1131            .with(|message| async move { to_axum_message(message) });
1132        let connection = Connection::new(Box::pin(socket));
1133        async move {
1134            server
1135                .handle_connection(
1136                    connection,
1137                    socket_address,
1138                    principal,
1139                    version,
1140                    country_code_header.map(|header| header.to_string()),
1141                    system_id_header.map(|header| header.to_string()),
1142                    None,
1143                    Executor::Production,
1144                )
1145                .await;
1146        }
1147    })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152    let connections_metric = CONNECTIONS_METRIC
1153        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155    let connections = server
1156        .connection_pool
1157        .lock()
1158        .connections()
1159        .filter(|connection| !connection.admin)
1160        .count();
1161    connections_metric.set(connections as _);
1162
1163    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165        register_int_gauge!(
1166            "shared_projects",
1167            "number of open projects with one or more guests"
1168        )
1169        .unwrap()
1170    });
1171
1172    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173    shared_projects_metric.set(shared_projects as _);
1174
1175    let encoder = prometheus::TextEncoder::new();
1176    let metric_families = prometheus::gather();
1177    let encoded_metrics = encoder
1178        .encode_to_string(&metric_families)
1179        .map_err(|err| anyhow!("{err}"))?;
1180    Ok(encoded_metrics)
1181}
1182
1183#[instrument(err, skip(executor))]
1184async fn connection_lost(
1185    session: Session,
1186    mut teardown: watch::Receiver<bool>,
1187    executor: Executor,
1188) -> Result<()> {
1189    session.peer.disconnect(session.connection_id);
1190    session
1191        .connection_pool()
1192        .await
1193        .remove_connection(session.connection_id)?;
1194
1195    session
1196        .db()
1197        .await
1198        .connection_lost(session.connection_id)
1199        .await
1200        .trace_err();
1201
1202    futures::select_biased! {
1203        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1204
1205            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1206            leave_room_for_session(&session, session.connection_id).await.trace_err();
1207            leave_channel_buffers_for_session(&session)
1208                .await
1209                .trace_err();
1210
1211            if !session
1212                .connection_pool()
1213                .await
1214                .is_user_online(session.user_id())
1215            {
1216                let db = session.db().await;
1217                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1218                    room_updated(&room, &session.peer);
1219                }
1220            }
1221
1222            update_user_contacts(session.user_id(), &session).await?;
1223        },
1224        _ = teardown.changed().fuse() => {}
1225    }
1226
1227    Ok(())
1228}
1229
1230/// Acknowledges a ping from a client, used to keep the connection alive.
1231async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1232    response.send(proto::Ack {})?;
1233    Ok(())
1234}
1235
1236/// Creates a new room for calling (outside of channels)
1237async fn create_room(
1238    _request: proto::CreateRoom,
1239    response: Response<proto::CreateRoom>,
1240    session: Session,
1241) -> Result<()> {
1242    let livekit_room = nanoid::nanoid!(30);
1243
1244    let live_kit_connection_info = util::maybe!(async {
1245        let live_kit = session.app_state.livekit_client.as_ref();
1246        let live_kit = live_kit?;
1247        let user_id = session.user_id().to_string();
1248
1249        let token = live_kit
1250            .room_token(&livekit_room, &user_id.to_string())
1251            .trace_err()?;
1252
1253        Some(proto::LiveKitConnectionInfo {
1254            server_url: live_kit.url().into(),
1255            token,
1256            can_publish: true,
1257        })
1258    })
1259    .await;
1260
1261    let room = session
1262        .db()
1263        .await
1264        .create_room(session.user_id(), session.connection_id, &livekit_room)
1265        .await?;
1266
1267    response.send(proto::CreateRoomResponse {
1268        room: Some(room.clone()),
1269        live_kit_connection_info,
1270    })?;
1271
1272    update_user_contacts(session.user_id(), &session).await?;
1273    Ok(())
1274}
1275
1276/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1277async fn join_room(
1278    request: proto::JoinRoom,
1279    response: Response<proto::JoinRoom>,
1280    session: Session,
1281) -> Result<()> {
1282    let room_id = RoomId::from_proto(request.id);
1283
1284    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1285
1286    if let Some(channel_id) = channel_id {
1287        return join_channel_internal(channel_id, Box::new(response), session).await;
1288    }
1289
1290    let joined_room = {
1291        let room = session
1292            .db()
1293            .await
1294            .join_room(room_id, session.user_id(), session.connection_id)
1295            .await?;
1296        room_updated(&room.room, &session.peer);
1297        room.into_inner()
1298    };
1299
1300    for connection_id in session
1301        .connection_pool()
1302        .await
1303        .user_connection_ids(session.user_id())
1304    {
1305        session
1306            .peer
1307            .send(
1308                connection_id,
1309                proto::CallCanceled {
1310                    room_id: room_id.to_proto(),
1311                },
1312            )
1313            .trace_err();
1314    }
1315
1316    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1317    {
1318        live_kit
1319            .room_token(
1320                &joined_room.room.livekit_room,
1321                &session.user_id().to_string(),
1322            )
1323            .trace_err()
1324            .map(|token| proto::LiveKitConnectionInfo {
1325                server_url: live_kit.url().into(),
1326                token,
1327                can_publish: true,
1328            })
1329    } else {
1330        None
1331    };
1332
1333    response.send(proto::JoinRoomResponse {
1334        room: Some(joined_room.room),
1335        channel_id: None,
1336        live_kit_connection_info,
1337    })?;
1338
1339    update_user_contacts(session.user_id(), &session).await?;
1340    Ok(())
1341}
1342
/// Rejoins a room after connection errors, reconnecting this client to it.
///
/// Asks the database to rejoin the room described by `request` and replies with the room,
/// the projects this client re-shared, and the projects it rejoined. It then:
/// - broadcasts the room update;
/// - tells collaborators of re-shared and rejoined projects about this client's new peer id;
/// - forwards re-shared worktrees to the other collaborators;
/// - streams rejoined project state back to this client;
/// - broadcasts a channel update when the room belongs to a channel;
/// - updates the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // All work that needs `rejoined_room` happens in this scope. It is reduced to plain
    // `room`/`channel` values via `into_inner()` before the channel update below.
    // NOTE(review): `rejoined_room` looks like a guard type (it is unwrapped with
    // `into_inner()`); keeping the sends inside this scope presumably orders them relative to
    // other room updates — confirm.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this client re-shared: tell every collaborator that the host's
        // peer id changed (failures logged), then forward the project's worktrees to every
        // collaborator except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Moves worktree/repository data out of `rejoined_projects` while streaming it.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1434
1435fn notify_rejoined_projects(
1436    rejoined_projects: &mut Vec<RejoinedProject>,
1437    session: &Session,
1438) -> Result<()> {
1439    for project in rejoined_projects.iter() {
1440        for collaborator in &project.collaborators {
1441            session
1442                .peer
1443                .send(
1444                    collaborator.connection_id,
1445                    proto::UpdateProjectCollaborator {
1446                        project_id: project.id.to_proto(),
1447                        old_peer_id: Some(project.old_connection_id.into()),
1448                        new_peer_id: Some(session.connection_id.into()),
1449                    },
1450                )
1451                .trace_err();
1452        }
1453    }
1454
1455    for project in rejoined_projects {
1456        for worktree in mem::take(&mut project.worktrees) {
1457            // Stream this worktree's entries.
1458            let message = proto::UpdateWorktree {
1459                project_id: project.id.to_proto(),
1460                worktree_id: worktree.id,
1461                abs_path: worktree.abs_path.clone(),
1462                root_name: worktree.root_name,
1463                updated_entries: worktree.updated_entries,
1464                removed_entries: worktree.removed_entries,
1465                scan_id: worktree.scan_id,
1466                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1467                updated_repositories: worktree.updated_repositories,
1468                removed_repositories: worktree.removed_repositories,
1469            };
1470            for update in proto::split_worktree_update(message) {
1471                session.peer.send(session.connection_id, update)?;
1472            }
1473
1474            // Stream this worktree's diagnostics.
1475            for summary in worktree.diagnostic_summaries {
1476                session.peer.send(
1477                    session.connection_id,
1478                    proto::UpdateDiagnosticSummary {
1479                        project_id: project.id.to_proto(),
1480                        worktree_id: worktree.id,
1481                        summary: Some(summary),
1482                    },
1483                )?;
1484            }
1485
1486            for settings_file in worktree.settings_files {
1487                session.peer.send(
1488                    session.connection_id,
1489                    proto::UpdateWorktreeSettings {
1490                        project_id: project.id.to_proto(),
1491                        worktree_id: worktree.id,
1492                        path: settings_file.path,
1493                        content: Some(settings_file.content),
1494                        kind: Some(settings_file.kind.to_proto().into()),
1495                    },
1496                )?;
1497            }
1498        }
1499
1500        for repository in mem::take(&mut project.updated_repositories) {
1501            for update in split_repository_update(repository) {
1502                session.peer.send(session.connection_id, update)?;
1503            }
1504        }
1505
1506        for id in mem::take(&mut project.removed_repositories) {
1507            session.peer.send(
1508                session.connection_id,
1509                proto::RemoveRepository {
1510                    project_id: project.id.to_proto(),
1511                    id,
1512                },
1513            )?;
1514        }
1515    }
1516
1517    Ok(())
1518}
1519
1520/// leave room disconnects from the room.
1521async fn leave_room(
1522    _: proto::LeaveRoom,
1523    response: Response<proto::LeaveRoom>,
1524    session: Session,
1525) -> Result<()> {
1526    leave_room_for_session(&session, session.connection_id).await?;
1527    response.send(proto::Ack {})?;
1528    Ok(())
1529}
1530
1531/// Updates the permissions of someone else in the room.
1532async fn set_room_participant_role(
1533    request: proto::SetRoomParticipantRole,
1534    response: Response<proto::SetRoomParticipantRole>,
1535    session: Session,
1536) -> Result<()> {
1537    let user_id = UserId::from_proto(request.user_id);
1538    let role = ChannelRole::from(request.role());
1539
1540    let (livekit_room, can_publish) = {
1541        let room = session
1542            .db()
1543            .await
1544            .set_room_participant_role(
1545                session.user_id(),
1546                RoomId::from_proto(request.room_id),
1547                user_id,
1548                role,
1549            )
1550            .await?;
1551
1552        let livekit_room = room.livekit_room.clone();
1553        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1554        room_updated(&room, &session.peer);
1555        (livekit_room, can_publish)
1556    };
1557
1558    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1559        live_kit
1560            .update_participant(
1561                livekit_room.clone(),
1562                request.user_id.to_string(),
1563                livekit_api::proto::ParticipantPermission {
1564                    can_subscribe: true,
1565                    can_publish,
1566                    can_publish_data: can_publish,
1567                    hidden: false,
1568                    recorder: false,
1569                },
1570            )
1571            .await
1572            .trace_err();
1573    }
1574
1575    response.send(proto::Ack {})?;
1576    Ok(())
1577}
1578
1579/// Call someone else into the current room
1580async fn call(
1581    request: proto::Call,
1582    response: Response<proto::Call>,
1583    session: Session,
1584) -> Result<()> {
1585    let room_id = RoomId::from_proto(request.room_id);
1586    let calling_user_id = session.user_id();
1587    let calling_connection_id = session.connection_id;
1588    let called_user_id = UserId::from_proto(request.called_user_id);
1589    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1590    if !session
1591        .db()
1592        .await
1593        .has_contact(calling_user_id, called_user_id)
1594        .await?
1595    {
1596        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1597    }
1598
1599    let incoming_call = {
1600        let (room, incoming_call) = &mut *session
1601            .db()
1602            .await
1603            .call(
1604                room_id,
1605                calling_user_id,
1606                calling_connection_id,
1607                called_user_id,
1608                initial_project_id,
1609            )
1610            .await?;
1611        room_updated(room, &session.peer);
1612        mem::take(incoming_call)
1613    };
1614    update_user_contacts(called_user_id, &session).await?;
1615
1616    let mut calls = session
1617        .connection_pool()
1618        .await
1619        .user_connection_ids(called_user_id)
1620        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1621        .collect::<FuturesUnordered<_>>();
1622
1623    while let Some(call_response) = calls.next().await {
1624        match call_response.as_ref() {
1625            Ok(_) => {
1626                response.send(proto::Ack {})?;
1627                return Ok(());
1628            }
1629            Err(_) => {
1630                call_response.trace_err();
1631            }
1632        }
1633    }
1634
1635    {
1636        let room = session
1637            .db()
1638            .await
1639            .call_failed(room_id, called_user_id)
1640            .await?;
1641        room_updated(&room, &session.peer);
1642    }
1643    update_user_contacts(called_user_id, &session).await?;
1644
1645    Err(anyhow!("failed to ring user"))?
1646}
1647
1648/// Cancel an outgoing call.
1649async fn cancel_call(
1650    request: proto::CancelCall,
1651    response: Response<proto::CancelCall>,
1652    session: Session,
1653) -> Result<()> {
1654    let called_user_id = UserId::from_proto(request.called_user_id);
1655    let room_id = RoomId::from_proto(request.room_id);
1656    {
1657        let room = session
1658            .db()
1659            .await
1660            .cancel_call(room_id, session.connection_id, called_user_id)
1661            .await?;
1662        room_updated(&room, &session.peer);
1663    }
1664
1665    for connection_id in session
1666        .connection_pool()
1667        .await
1668        .user_connection_ids(called_user_id)
1669    {
1670        session
1671            .peer
1672            .send(
1673                connection_id,
1674                proto::CallCanceled {
1675                    room_id: room_id.to_proto(),
1676                },
1677            )
1678            .trace_err();
1679    }
1680    response.send(proto::Ack {})?;
1681
1682    update_user_contacts(called_user_id, &session).await?;
1683    Ok(())
1684}
1685
1686/// Decline an incoming call.
1687async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1688    let room_id = RoomId::from_proto(message.room_id);
1689    {
1690        let room = session
1691            .db()
1692            .await
1693            .decline_call(Some(room_id), session.user_id())
1694            .await?
1695            .context("declining call")?;
1696        room_updated(&room, &session.peer);
1697    }
1698
1699    for connection_id in session
1700        .connection_pool()
1701        .await
1702        .user_connection_ids(session.user_id())
1703    {
1704        session
1705            .peer
1706            .send(
1707                connection_id,
1708                proto::CallCanceled {
1709                    room_id: room_id.to_proto(),
1710                },
1711            )
1712            .trace_err();
1713    }
1714    update_user_contacts(session.user_id(), &session).await?;
1715    Ok(())
1716}
1717
1718/// Updates other participants in the room with your current location.
1719async fn update_participant_location(
1720    request: proto::UpdateParticipantLocation,
1721    response: Response<proto::UpdateParticipantLocation>,
1722    session: Session,
1723) -> Result<()> {
1724    let room_id = RoomId::from_proto(request.room_id);
1725    let location = request.location.context("invalid location")?;
1726
1727    let db = session.db().await;
1728    let room = db
1729        .update_room_participant_location(room_id, session.connection_id, location)
1730        .await?;
1731
1732    room_updated(&room, &session.peer);
1733    response.send(proto::Ack {})?;
1734    Ok(())
1735}
1736
1737/// Share a project into the room.
1738async fn share_project(
1739    request: proto::ShareProject,
1740    response: Response<proto::ShareProject>,
1741    session: Session,
1742) -> Result<()> {
1743    let (project_id, room) = &*session
1744        .db()
1745        .await
1746        .share_project(
1747            RoomId::from_proto(request.room_id),
1748            session.connection_id,
1749            &request.worktrees,
1750            request.is_ssh_project,
1751        )
1752        .await?;
1753    response.send(proto::ShareProjectResponse {
1754        project_id: project_id.to_proto(),
1755    })?;
1756    room_updated(room, &session.peer);
1757
1758    Ok(())
1759}
1760
1761/// Unshare a project from the room.
1762async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1763    let project_id = ProjectId::from_proto(message.project_id);
1764    unshare_project_internal(project_id, session.connection_id, &session).await
1765}
1766
1767async fn unshare_project_internal(
1768    project_id: ProjectId,
1769    connection_id: ConnectionId,
1770    session: &Session,
1771) -> Result<()> {
1772    let delete = {
1773        let room_guard = session
1774            .db()
1775            .await
1776            .unshare_project(project_id, connection_id)
1777            .await?;
1778
1779        let (delete, room, guest_connection_ids) = &*room_guard;
1780
1781        let message = proto::UnshareProject {
1782            project_id: project_id.to_proto(),
1783        };
1784
1785        broadcast(
1786            Some(connection_id),
1787            guest_connection_ids.iter().copied(),
1788            |conn_id| session.peer.send(conn_id, message.clone()),
1789        );
1790        if let Some(room) = room {
1791            room_updated(room, &session.peer);
1792        }
1793
1794        *delete
1795    };
1796
1797    if delete {
1798        let db = session.db().await;
1799        db.delete_project(project_id).await?;
1800    }
1801
1802    Ok(())
1803}
1804
1805/// Join someone elses shared project.
1806async fn join_project(
1807    request: proto::JoinProject,
1808    response: Response<proto::JoinProject>,
1809    session: Session,
1810) -> Result<()> {
1811    let project_id = ProjectId::from_proto(request.project_id);
1812
1813    tracing::info!(%project_id, "join project");
1814
1815    let db = session.db().await;
1816    let (project, replica_id) = &mut *db
1817        .join_project(project_id, session.connection_id, session.user_id())
1818        .await?;
1819    drop(db);
1820    tracing::info!(%project_id, "join remote project");
1821    join_project_internal(response, session, project, replica_id)
1822}
1823
/// Sends the reply to a join-project request.
///
/// `join_project_internal` takes `impl JoinProjectInternalResponse` rather than
/// a concrete response type. The only implementation visible here is for
/// `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delegates to the inherent `Response::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully-qualified path makes it explicit that this forwards to the
        // inherent `Response::send`, not to this trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1832
1833fn join_project_internal(
1834    response: impl JoinProjectInternalResponse,
1835    session: Session,
1836    project: &mut Project,
1837    replica_id: &ReplicaId,
1838) -> Result<()> {
1839    let collaborators = project
1840        .collaborators
1841        .iter()
1842        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1843        .map(|collaborator| collaborator.to_proto())
1844        .collect::<Vec<_>>();
1845    let project_id = project.id;
1846    let guest_user_id = session.user_id();
1847
1848    let worktrees = project
1849        .worktrees
1850        .iter()
1851        .map(|(id, worktree)| proto::WorktreeMetadata {
1852            id: *id,
1853            root_name: worktree.root_name.clone(),
1854            visible: worktree.visible,
1855            abs_path: worktree.abs_path.clone(),
1856        })
1857        .collect::<Vec<_>>();
1858
1859    let add_project_collaborator = proto::AddProjectCollaborator {
1860        project_id: project_id.to_proto(),
1861        collaborator: Some(proto::Collaborator {
1862            peer_id: Some(session.connection_id.into()),
1863            replica_id: replica_id.0 as u32,
1864            user_id: guest_user_id.to_proto(),
1865            is_host: false,
1866        }),
1867    };
1868
1869    for collaborator in &collaborators {
1870        session
1871            .peer
1872            .send(
1873                collaborator.peer_id.unwrap().into(),
1874                add_project_collaborator.clone(),
1875            )
1876            .trace_err();
1877    }
1878
1879    // First, we send the metadata associated with each worktree.
1880    response.send(proto::JoinProjectResponse {
1881        project_id: project.id.0 as u64,
1882        worktrees: worktrees.clone(),
1883        replica_id: replica_id.0 as u32,
1884        collaborators: collaborators.clone(),
1885        language_servers: project.language_servers.clone(),
1886        role: project.role.into(),
1887    })?;
1888
1889    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1890        // Stream this worktree's entries.
1891        let message = proto::UpdateWorktree {
1892            project_id: project_id.to_proto(),
1893            worktree_id,
1894            abs_path: worktree.abs_path.clone(),
1895            root_name: worktree.root_name,
1896            updated_entries: worktree.entries,
1897            removed_entries: Default::default(),
1898            scan_id: worktree.scan_id,
1899            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1900            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1901            removed_repositories: Default::default(),
1902        };
1903        for update in proto::split_worktree_update(message) {
1904            session.peer.send(session.connection_id, update.clone())?;
1905        }
1906
1907        // Stream this worktree's diagnostics.
1908        for summary in worktree.diagnostic_summaries {
1909            session.peer.send(
1910                session.connection_id,
1911                proto::UpdateDiagnosticSummary {
1912                    project_id: project_id.to_proto(),
1913                    worktree_id: worktree.id,
1914                    summary: Some(summary),
1915                },
1916            )?;
1917        }
1918
1919        for settings_file in worktree.settings_files {
1920            session.peer.send(
1921                session.connection_id,
1922                proto::UpdateWorktreeSettings {
1923                    project_id: project_id.to_proto(),
1924                    worktree_id: worktree.id,
1925                    path: settings_file.path,
1926                    content: Some(settings_file.content),
1927                    kind: Some(settings_file.kind.to_proto() as i32),
1928                },
1929            )?;
1930        }
1931    }
1932
1933    for repository in mem::take(&mut project.repositories) {
1934        for update in split_repository_update(repository) {
1935            session.peer.send(session.connection_id, update)?;
1936        }
1937    }
1938
1939    for language_server in &project.language_servers {
1940        session.peer.send(
1941            session.connection_id,
1942            proto::UpdateLanguageServer {
1943                project_id: project_id.to_proto(),
1944                language_server_id: language_server.id,
1945                variant: Some(
1946                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1947                        proto::LspDiskBasedDiagnosticsUpdated {},
1948                    ),
1949                ),
1950            },
1951        )?;
1952    }
1953
1954    Ok(())
1955}
1956
1957/// Leave someone elses shared project.
1958async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1959    let sender_id = session.connection_id;
1960    let project_id = ProjectId::from_proto(request.project_id);
1961    let db = session.db().await;
1962
1963    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1964    tracing::info!(
1965        %project_id,
1966        "leave project"
1967    );
1968
1969    project_left(project, &session);
1970    if let Some(room) = room {
1971        room_updated(room, &session.peer);
1972    }
1973
1974    Ok(())
1975}
1976
1977/// Updates other participants with changes to the project
1978async fn update_project(
1979    request: proto::UpdateProject,
1980    response: Response<proto::UpdateProject>,
1981    session: Session,
1982) -> Result<()> {
1983    let project_id = ProjectId::from_proto(request.project_id);
1984    let (room, guest_connection_ids) = &*session
1985        .db()
1986        .await
1987        .update_project(project_id, session.connection_id, &request.worktrees)
1988        .await?;
1989    broadcast(
1990        Some(session.connection_id),
1991        guest_connection_ids.iter().copied(),
1992        |connection_id| {
1993            session
1994                .peer
1995                .forward_send(session.connection_id, connection_id, request.clone())
1996        },
1997    );
1998    if let Some(room) = room {
1999        room_updated(room, &session.peer);
2000    }
2001    response.send(proto::Ack {})?;
2002
2003    Ok(())
2004}
2005
2006/// Updates other participants with changes to the worktree
2007async fn update_worktree(
2008    request: proto::UpdateWorktree,
2009    response: Response<proto::UpdateWorktree>,
2010    session: Session,
2011) -> Result<()> {
2012    let guest_connection_ids = session
2013        .db()
2014        .await
2015        .update_worktree(&request, session.connection_id)
2016        .await?;
2017
2018    broadcast(
2019        Some(session.connection_id),
2020        guest_connection_ids.iter().copied(),
2021        |connection_id| {
2022            session
2023                .peer
2024                .forward_send(session.connection_id, connection_id, request.clone())
2025        },
2026    );
2027    response.send(proto::Ack {})?;
2028    Ok(())
2029}
2030
2031async fn update_repository(
2032    request: proto::UpdateRepository,
2033    response: Response<proto::UpdateRepository>,
2034    session: Session,
2035) -> Result<()> {
2036    let guest_connection_ids = session
2037        .db()
2038        .await
2039        .update_repository(&request, session.connection_id)
2040        .await?;
2041
2042    broadcast(
2043        Some(session.connection_id),
2044        guest_connection_ids.iter().copied(),
2045        |connection_id| {
2046            session
2047                .peer
2048                .forward_send(session.connection_id, connection_id, request.clone())
2049        },
2050    );
2051    response.send(proto::Ack {})?;
2052    Ok(())
2053}
2054
2055async fn remove_repository(
2056    request: proto::RemoveRepository,
2057    response: Response<proto::RemoveRepository>,
2058    session: Session,
2059) -> Result<()> {
2060    let guest_connection_ids = session
2061        .db()
2062        .await
2063        .remove_repository(&request, session.connection_id)
2064        .await?;
2065
2066    broadcast(
2067        Some(session.connection_id),
2068        guest_connection_ids.iter().copied(),
2069        |connection_id| {
2070            session
2071                .peer
2072                .forward_send(session.connection_id, connection_id, request.clone())
2073        },
2074    );
2075    response.send(proto::Ack {})?;
2076    Ok(())
2077}
2078
2079/// Updates other participants with changes to the diagnostics
2080async fn update_diagnostic_summary(
2081    message: proto::UpdateDiagnosticSummary,
2082    session: Session,
2083) -> Result<()> {
2084    let guest_connection_ids = session
2085        .db()
2086        .await
2087        .update_diagnostic_summary(&message, session.connection_id)
2088        .await?;
2089
2090    broadcast(
2091        Some(session.connection_id),
2092        guest_connection_ids.iter().copied(),
2093        |connection_id| {
2094            session
2095                .peer
2096                .forward_send(session.connection_id, connection_id, message.clone())
2097        },
2098    );
2099
2100    Ok(())
2101}
2102
2103/// Updates other participants with changes to the worktree settings
2104async fn update_worktree_settings(
2105    message: proto::UpdateWorktreeSettings,
2106    session: Session,
2107) -> Result<()> {
2108    let guest_connection_ids = session
2109        .db()
2110        .await
2111        .update_worktree_settings(&message, session.connection_id)
2112        .await?;
2113
2114    broadcast(
2115        Some(session.connection_id),
2116        guest_connection_ids.iter().copied(),
2117        |connection_id| {
2118            session
2119                .peer
2120                .forward_send(session.connection_id, connection_id, message.clone())
2121        },
2122    );
2123
2124    Ok(())
2125}
2126
2127/// Notify other participants that a language server has started.
2128async fn start_language_server(
2129    request: proto::StartLanguageServer,
2130    session: Session,
2131) -> Result<()> {
2132    let guest_connection_ids = session
2133        .db()
2134        .await
2135        .start_language_server(&request, session.connection_id)
2136        .await?;
2137
2138    broadcast(
2139        Some(session.connection_id),
2140        guest_connection_ids.iter().copied(),
2141        |connection_id| {
2142            session
2143                .peer
2144                .forward_send(session.connection_id, connection_id, request.clone())
2145        },
2146    );
2147    Ok(())
2148}
2149
2150/// Notify other participants that a language server has changed.
2151async fn update_language_server(
2152    request: proto::UpdateLanguageServer,
2153    session: Session,
2154) -> Result<()> {
2155    let project_id = ProjectId::from_proto(request.project_id);
2156    let project_connection_ids = session
2157        .db()
2158        .await
2159        .project_connection_ids(project_id, session.connection_id, true)
2160        .await?;
2161    broadcast(
2162        Some(session.connection_id),
2163        project_connection_ids.iter().copied(),
2164        |connection_id| {
2165            session
2166                .peer
2167                .forward_send(session.connection_id, connection_id, request.clone())
2168        },
2169    );
2170    Ok(())
2171}
2172
2173/// forward a project request to the host. These requests should be read only
2174/// as guests are allowed to send them.
2175async fn forward_read_only_project_request<T>(
2176    request: T,
2177    response: Response<T>,
2178    session: Session,
2179) -> Result<()>
2180where
2181    T: EntityMessage + RequestMessage,
2182{
2183    let project_id = ProjectId::from_proto(request.remote_entity_id());
2184    let host_connection_id = session
2185        .db()
2186        .await
2187        .host_for_read_only_project_request(project_id, session.connection_id)
2188        .await?;
2189    let payload = session
2190        .peer
2191        .forward_request(session.connection_id, host_connection_id, request)
2192        .await?;
2193    response.send(payload)?;
2194    Ok(())
2195}
2196
2197async fn forward_find_search_candidates_request(
2198    request: proto::FindSearchCandidates,
2199    response: Response<proto::FindSearchCandidates>,
2200    session: Session,
2201) -> Result<()> {
2202    let project_id = ProjectId::from_proto(request.remote_entity_id());
2203    let host_connection_id = session
2204        .db()
2205        .await
2206        .host_for_read_only_project_request(project_id, session.connection_id)
2207        .await?;
2208    let payload = session
2209        .peer
2210        .forward_request(session.connection_id, host_connection_id, request)
2211        .await?;
2212    response.send(payload)?;
2213    Ok(())
2214}
2215
2216/// forward a project request to the host. These requests are disallowed
2217/// for guests.
2218async fn forward_mutating_project_request<T>(
2219    request: T,
2220    response: Response<T>,
2221    session: Session,
2222) -> Result<()>
2223where
2224    T: EntityMessage + RequestMessage,
2225{
2226    let project_id = ProjectId::from_proto(request.remote_entity_id());
2227
2228    let host_connection_id = session
2229        .db()
2230        .await
2231        .host_for_mutating_project_request(project_id, session.connection_id)
2232        .await?;
2233    let payload = session
2234        .peer
2235        .forward_request(session.connection_id, host_connection_id, request)
2236        .await?;
2237    response.send(payload)?;
2238    Ok(())
2239}
2240
2241/// Notify other participants that a new buffer has been created
2242async fn create_buffer_for_peer(
2243    request: proto::CreateBufferForPeer,
2244    session: Session,
2245) -> Result<()> {
2246    session
2247        .db()
2248        .await
2249        .check_user_is_project_host(
2250            ProjectId::from_proto(request.project_id),
2251            session.connection_id,
2252        )
2253        .await?;
2254    let peer_id = request.peer_id.context("invalid peer id")?;
2255    session
2256        .peer
2257        .forward_send(session.connection_id, peer_id.into(), request)?;
2258    Ok(())
2259}
2260
2261/// Notify other participants that a buffer has been updated. This is
2262/// allowed for guests as long as the update is limited to selections.
2263async fn update_buffer(
2264    request: proto::UpdateBuffer,
2265    response: Response<proto::UpdateBuffer>,
2266    session: Session,
2267) -> Result<()> {
2268    let project_id = ProjectId::from_proto(request.project_id);
2269    let mut capability = Capability::ReadOnly;
2270
2271    for op in request.operations.iter() {
2272        match op.variant {
2273            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2274            Some(_) => capability = Capability::ReadWrite,
2275        }
2276    }
2277
2278    let host = {
2279        let guard = session
2280            .db()
2281            .await
2282            .connections_for_buffer_update(project_id, session.connection_id, capability)
2283            .await?;
2284
2285        let (host, guests) = &*guard;
2286
2287        broadcast(
2288            Some(session.connection_id),
2289            guests.clone(),
2290            |connection_id| {
2291                session
2292                    .peer
2293                    .forward_send(session.connection_id, connection_id, request.clone())
2294            },
2295        );
2296
2297        *host
2298    };
2299
2300    if host != session.connection_id {
2301        session
2302            .peer
2303            .forward_request(session.connection_id, host, request.clone())
2304            .await?;
2305    }
2306
2307    response.send(proto::Ack {})?;
2308    Ok(())
2309}
2310
2311async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2312    let project_id = ProjectId::from_proto(message.project_id);
2313
2314    let operation = message.operation.as_ref().context("invalid operation")?;
2315    let capability = match operation.variant.as_ref() {
2316        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2317            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2318                match buffer_op.variant {
2319                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2320                        Capability::ReadOnly
2321                    }
2322                    _ => Capability::ReadWrite,
2323                }
2324            } else {
2325                Capability::ReadWrite
2326            }
2327        }
2328        Some(_) => Capability::ReadWrite,
2329        None => Capability::ReadOnly,
2330    };
2331
2332    let guard = session
2333        .db()
2334        .await
2335        .connections_for_buffer_update(project_id, session.connection_id, capability)
2336        .await?;
2337
2338    let (host, guests) = &*guard;
2339
2340    broadcast(
2341        Some(session.connection_id),
2342        guests.iter().chain([host]).copied(),
2343        |connection_id| {
2344            session
2345                .peer
2346                .forward_send(session.connection_id, connection_id, message.clone())
2347        },
2348    );
2349
2350    Ok(())
2351}
2352
2353/// Notify other participants that a project has been updated.
2354async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2355    request: T,
2356    session: Session,
2357) -> Result<()> {
2358    let project_id = ProjectId::from_proto(request.remote_entity_id());
2359    let project_connection_ids = session
2360        .db()
2361        .await
2362        .project_connection_ids(project_id, session.connection_id, false)
2363        .await?;
2364
2365    broadcast(
2366        Some(session.connection_id),
2367        project_connection_ids.iter().copied(),
2368        |connection_id| {
2369            session
2370                .peer
2371                .forward_send(session.connection_id, connection_id, request.clone())
2372        },
2373    );
2374    Ok(())
2375}
2376
2377/// Start following another user in a call.
2378async fn follow(
2379    request: proto::Follow,
2380    response: Response<proto::Follow>,
2381    session: Session,
2382) -> Result<()> {
2383    let room_id = RoomId::from_proto(request.room_id);
2384    let project_id = request.project_id.map(ProjectId::from_proto);
2385    let leader_id = request.leader_id.context("invalid leader id")?.into();
2386    let follower_id = session.connection_id;
2387
2388    session
2389        .db()
2390        .await
2391        .check_room_participants(room_id, leader_id, session.connection_id)
2392        .await?;
2393
2394    let response_payload = session
2395        .peer
2396        .forward_request(session.connection_id, leader_id, request)
2397        .await?;
2398    response.send(response_payload)?;
2399
2400    if let Some(project_id) = project_id {
2401        let room = session
2402            .db()
2403            .await
2404            .follow(room_id, project_id, leader_id, follower_id)
2405            .await?;
2406        room_updated(&room, &session.peer);
2407    }
2408
2409    Ok(())
2410}
2411
2412/// Stop following another user in a call.
2413async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2414    let room_id = RoomId::from_proto(request.room_id);
2415    let project_id = request.project_id.map(ProjectId::from_proto);
2416    let leader_id = request.leader_id.context("invalid leader id")?.into();
2417    let follower_id = session.connection_id;
2418
2419    session
2420        .db()
2421        .await
2422        .check_room_participants(room_id, leader_id, session.connection_id)
2423        .await?;
2424
2425    session
2426        .peer
2427        .forward_send(session.connection_id, leader_id, request)?;
2428
2429    if let Some(project_id) = project_id {
2430        let room = session
2431            .db()
2432            .await
2433            .unfollow(room_id, project_id, leader_id, follower_id)
2434            .await?;
2435        room_updated(&room, &session.peer);
2436    }
2437
2438    Ok(())
2439}
2440
2441/// Notify everyone following you of your current location.
2442async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2443    let room_id = RoomId::from_proto(request.room_id);
2444    let database = session.db.lock().await;
2445
2446    let connection_ids = if let Some(project_id) = request.project_id {
2447        let project_id = ProjectId::from_proto(project_id);
2448        database
2449            .project_connection_ids(project_id, session.connection_id, true)
2450            .await?
2451    } else {
2452        database
2453            .room_connection_ids(room_id, session.connection_id)
2454            .await?
2455    };
2456
2457    // For now, don't send view update messages back to that view's current leader.
2458    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2459        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2460        _ => None,
2461    });
2462
2463    for connection_id in connection_ids.iter().cloned() {
2464        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2465            session
2466                .peer
2467                .forward_send(session.connection_id, connection_id, request.clone())?;
2468        }
2469    }
2470    Ok(())
2471}
2472
2473/// Get public data about users.
2474async fn get_users(
2475    request: proto::GetUsers,
2476    response: Response<proto::GetUsers>,
2477    session: Session,
2478) -> Result<()> {
2479    let user_ids = request
2480        .user_ids
2481        .into_iter()
2482        .map(UserId::from_proto)
2483        .collect();
2484    let users = session
2485        .db()
2486        .await
2487        .get_users_by_ids(user_ids)
2488        .await?
2489        .into_iter()
2490        .map(|user| proto::User {
2491            id: user.id.to_proto(),
2492            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2493            github_login: user.github_login,
2494            email: user.email_address,
2495            name: user.name,
2496        })
2497        .collect();
2498    response.send(proto::UsersResponse { users })?;
2499    Ok(())
2500}
2501
2502/// Search for users (to invite) buy Github login
2503async fn fuzzy_search_users(
2504    request: proto::FuzzySearchUsers,
2505    response: Response<proto::FuzzySearchUsers>,
2506    session: Session,
2507) -> Result<()> {
2508    let query = request.query;
2509    let users = match query.len() {
2510        0 => vec![],
2511        1 | 2 => session
2512            .db()
2513            .await
2514            .get_user_by_github_login(&query)
2515            .await?
2516            .into_iter()
2517            .collect(),
2518        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2519    };
2520    let users = users
2521        .into_iter()
2522        .filter(|user| user.id != session.user_id())
2523        .map(|user| proto::User {
2524            id: user.id.to_proto(),
2525            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526            github_login: user.github_login,
2527            name: user.name,
2528            email: user.email_address,
2529        })
2530        .collect();
2531    response.send(proto::UsersResponse { users })?;
2532    Ok(())
2533}
2534
2535/// Send a contact request to another user.
2536async fn request_contact(
2537    request: proto::RequestContact,
2538    response: Response<proto::RequestContact>,
2539    session: Session,
2540) -> Result<()> {
2541    let requester_id = session.user_id();
2542    let responder_id = UserId::from_proto(request.responder_id);
2543    if requester_id == responder_id {
2544        return Err(anyhow!("cannot add yourself as a contact"))?;
2545    }
2546
2547    let notifications = session
2548        .db()
2549        .await
2550        .send_contact_request(requester_id, responder_id)
2551        .await?;
2552
2553    // Update outgoing contact requests of requester
2554    let mut update = proto::UpdateContacts::default();
2555    update.outgoing_requests.push(responder_id.to_proto());
2556    for connection_id in session
2557        .connection_pool()
2558        .await
2559        .user_connection_ids(requester_id)
2560    {
2561        session.peer.send(connection_id, update.clone())?;
2562    }
2563
2564    // Update incoming contact requests of responder
2565    let mut update = proto::UpdateContacts::default();
2566    update
2567        .incoming_requests
2568        .push(proto::IncomingContactRequest {
2569            requester_id: requester_id.to_proto(),
2570        });
2571    let connection_pool = session.connection_pool().await;
2572    for connection_id in connection_pool.user_connection_ids(responder_id) {
2573        session.peer.send(connection_id, update.clone())?;
2574    }
2575
2576    send_notifications(&connection_pool, &session.peer, notifications);
2577
2578    response.send(proto::Ack {})?;
2579    Ok(())
2580}
2581
2582/// Accept or decline a contact request
2583async fn respond_to_contact_request(
2584    request: proto::RespondToContactRequest,
2585    response: Response<proto::RespondToContactRequest>,
2586    session: Session,
2587) -> Result<()> {
2588    let responder_id = session.user_id();
2589    let requester_id = UserId::from_proto(request.requester_id);
2590    let db = session.db().await;
2591    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2592        db.dismiss_contact_notification(responder_id, requester_id)
2593            .await?;
2594    } else {
2595        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2596
2597        let notifications = db
2598            .respond_to_contact_request(responder_id, requester_id, accept)
2599            .await?;
2600        let requester_busy = db.is_user_busy(requester_id).await?;
2601        let responder_busy = db.is_user_busy(responder_id).await?;
2602
2603        let pool = session.connection_pool().await;
2604        // Update responder with new contact
2605        let mut update = proto::UpdateContacts::default();
2606        if accept {
2607            update
2608                .contacts
2609                .push(contact_for_user(requester_id, requester_busy, &pool));
2610        }
2611        update
2612            .remove_incoming_requests
2613            .push(requester_id.to_proto());
2614        for connection_id in pool.user_connection_ids(responder_id) {
2615            session.peer.send(connection_id, update.clone())?;
2616        }
2617
2618        // Update requester with new contact
2619        let mut update = proto::UpdateContacts::default();
2620        if accept {
2621            update
2622                .contacts
2623                .push(contact_for_user(responder_id, responder_busy, &pool));
2624        }
2625        update
2626            .remove_outgoing_requests
2627            .push(responder_id.to_proto());
2628
2629        for connection_id in pool.user_connection_ids(requester_id) {
2630            session.peer.send(connection_id, update.clone())?;
2631        }
2632
2633        send_notifications(&pool, &session.peer, notifications);
2634    }
2635
2636    response.send(proto::Ack {})?;
2637    Ok(())
2638}
2639
2640/// Remove a contact.
2641async fn remove_contact(
2642    request: proto::RemoveContact,
2643    response: Response<proto::RemoveContact>,
2644    session: Session,
2645) -> Result<()> {
2646    let requester_id = session.user_id();
2647    let responder_id = UserId::from_proto(request.user_id);
2648    let db = session.db().await;
2649    let (contact_accepted, deleted_notification_id) =
2650        db.remove_contact(requester_id, responder_id).await?;
2651
2652    let pool = session.connection_pool().await;
2653    // Update outgoing contact requests of requester
2654    let mut update = proto::UpdateContacts::default();
2655    if contact_accepted {
2656        update.remove_contacts.push(responder_id.to_proto());
2657    } else {
2658        update
2659            .remove_outgoing_requests
2660            .push(responder_id.to_proto());
2661    }
2662    for connection_id in pool.user_connection_ids(requester_id) {
2663        session.peer.send(connection_id, update.clone())?;
2664    }
2665
2666    // Update incoming contact requests of responder
2667    let mut update = proto::UpdateContacts::default();
2668    if contact_accepted {
2669        update.remove_contacts.push(requester_id.to_proto());
2670    } else {
2671        update
2672            .remove_incoming_requests
2673            .push(requester_id.to_proto());
2674    }
2675    for connection_id in pool.user_connection_ids(responder_id) {
2676        session.peer.send(connection_id, update.clone())?;
2677        if let Some(notification_id) = deleted_notification_id {
2678            session.peer.send(
2679                connection_id,
2680                proto::DeleteNotification {
2681                    notification_id: notification_id.to_proto(),
2682                },
2683            )?;
2684        }
2685    }
2686
2687    response.send(proto::Ack {})?;
2688    Ok(())
2689}
2690
2691fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2692    version.0.minor() < 139
2693}
2694
2695async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2696    if is_staff {
2697        return Ok(proto::Plan::ZedPro);
2698    }
2699
2700    let subscription = db.get_active_billing_subscription(user_id).await?;
2701    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2702
2703    let plan = if let Some(subscription_kind) = subscription_kind {
2704        match subscription_kind {
2705            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2706            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2707            SubscriptionKind::ZedFree => proto::Plan::Free,
2708        }
2709    } else {
2710        proto::Plan::Free
2711    };
2712
2713    Ok(plan)
2714}
2715
2716async fn make_update_user_plan_message(
2717    user: &User,
2718    is_staff: bool,
2719    db: &Arc<Database>,
2720    llm_db: Option<Arc<LlmDatabase>>,
2721) -> Result<proto::UpdateUserPlan> {
2722    let feature_flags = db.get_user_flags(user.id).await?;
2723    let plan = current_plan(db, user.id, is_staff).await?;
2724    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2725    let billing_preferences = db.get_billing_preferences(user.id).await?;
2726
2727    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2728        let subscription = db.get_active_billing_subscription(user.id).await?;
2729
2730        let subscription_period =
2731            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2732
2733        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2734            llm_db
2735                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2736                .await?
2737        } else {
2738            None
2739        };
2740
2741        (subscription_period, usage)
2742    } else {
2743        (None, None)
2744    };
2745
2746    let account_too_young =
2747        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2748
2749    Ok(proto::UpdateUserPlan {
2750        plan: plan.into(),
2751        trial_started_at: billing_customer
2752            .as_ref()
2753            .and_then(|billing_customer| billing_customer.trial_started_at)
2754            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2755        is_usage_based_billing_enabled: if is_staff {
2756            Some(true)
2757        } else {
2758            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2759        },
2760        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2761            proto::SubscriptionPeriod {
2762                started_at: started_at.timestamp() as u64,
2763                ended_at: ended_at.timestamp() as u64,
2764            }
2765        }),
2766        account_too_young: Some(account_too_young),
2767        has_overdue_invoices: billing_customer
2768            .map(|billing_customer| billing_customer.has_overdue_invoices),
2769        usage: usage.map(|usage| {
2770            let plan = match plan {
2771                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2772                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2773                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2774            };
2775
2776            let model_requests_limit = match plan.model_requests_limit() {
2777                zed_llm_client::UsageLimit::Limited(limit) => {
2778                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2779                        && feature_flags
2780                            .iter()
2781                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2782                    {
2783                        1_000
2784                    } else {
2785                        limit
2786                    };
2787
2788                    zed_llm_client::UsageLimit::Limited(limit)
2789                }
2790                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2791            };
2792
2793            proto::SubscriptionUsage {
2794                model_requests_usage_amount: usage.model_requests as u32,
2795                model_requests_usage_limit: Some(proto::UsageLimit {
2796                    variant: Some(match model_requests_limit {
2797                        zed_llm_client::UsageLimit::Limited(limit) => {
2798                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2799                                limit: limit as u32,
2800                            })
2801                        }
2802                        zed_llm_client::UsageLimit::Unlimited => {
2803                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2804                        }
2805                    }),
2806                }),
2807                edit_predictions_usage_amount: usage.edit_predictions as u32,
2808                edit_predictions_usage_limit: Some(proto::UsageLimit {
2809                    variant: Some(match plan.edit_predictions_limit() {
2810                        zed_llm_client::UsageLimit::Limited(limit) => {
2811                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2812                                limit: limit as u32,
2813                            })
2814                        }
2815                        zed_llm_client::UsageLimit::Unlimited => {
2816                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2817                        }
2818                    }),
2819                }),
2820            }
2821        }),
2822    })
2823}
2824
2825async fn update_user_plan(session: &Session) -> Result<()> {
2826    let db = session.db().await;
2827
2828    let update_user_plan = make_update_user_plan_message(
2829        session.principal.user(),
2830        session.is_staff(),
2831        &db.0,
2832        session.app_state.llm_db.clone(),
2833    )
2834    .await?;
2835
2836    session
2837        .peer
2838        .send(session.connection_id, update_user_plan)
2839        .trace_err();
2840
2841    Ok(())
2842}
2843
2844async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2845    subscribe_user_to_channels(session.user_id(), &session).await?;
2846    Ok(())
2847}
2848
2849async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2850    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2851    let mut pool = session.connection_pool().await;
2852    for membership in &channels_for_user.channel_memberships {
2853        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2854    }
2855    session.peer.send(
2856        session.connection_id,
2857        build_update_user_channels(&channels_for_user),
2858    )?;
2859    session.peer.send(
2860        session.connection_id,
2861        build_channels_update(channels_for_user),
2862    )?;
2863    Ok(())
2864}
2865
2866/// Creates a new channel.
2867async fn create_channel(
2868    request: proto::CreateChannel,
2869    response: Response<proto::CreateChannel>,
2870    session: Session,
2871) -> Result<()> {
2872    let db = session.db().await;
2873
2874    let parent_id = request.parent_id.map(ChannelId::from_proto);
2875    let (channel, membership) = db
2876        .create_channel(&request.name, parent_id, session.user_id())
2877        .await?;
2878
2879    let root_id = channel.root_id();
2880    let channel = Channel::from_model(channel);
2881
2882    response.send(proto::CreateChannelResponse {
2883        channel: Some(channel.to_proto()),
2884        parent_id: request.parent_id,
2885    })?;
2886
2887    let mut connection_pool = session.connection_pool().await;
2888    if let Some(membership) = membership {
2889        connection_pool.subscribe_to_channel(
2890            membership.user_id,
2891            membership.channel_id,
2892            membership.role,
2893        );
2894        let update = proto::UpdateUserChannels {
2895            channel_memberships: vec![proto::ChannelMembership {
2896                channel_id: membership.channel_id.to_proto(),
2897                role: membership.role.into(),
2898            }],
2899            ..Default::default()
2900        };
2901        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2902            session.peer.send(connection_id, update.clone())?;
2903        }
2904    }
2905
2906    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2907        if !role.can_see_channel(channel.visibility) {
2908            continue;
2909        }
2910
2911        let update = proto::UpdateChannels {
2912            channels: vec![channel.to_proto()],
2913            ..Default::default()
2914        };
2915        session.peer.send(connection_id, update.clone())?;
2916    }
2917
2918    Ok(())
2919}
2920
2921/// Delete a channel
2922async fn delete_channel(
2923    request: proto::DeleteChannel,
2924    response: Response<proto::DeleteChannel>,
2925    session: Session,
2926) -> Result<()> {
2927    let db = session.db().await;
2928
2929    let channel_id = request.channel_id;
2930    let (root_channel, removed_channels) = db
2931        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2932        .await?;
2933    response.send(proto::Ack {})?;
2934
2935    // Notify members of removed channels
2936    let mut update = proto::UpdateChannels::default();
2937    update
2938        .delete_channels
2939        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2940
2941    let connection_pool = session.connection_pool().await;
2942    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2943        session.peer.send(connection_id, update.clone())?;
2944    }
2945
2946    Ok(())
2947}
2948
2949/// Invite someone to join a channel.
2950async fn invite_channel_member(
2951    request: proto::InviteChannelMember,
2952    response: Response<proto::InviteChannelMember>,
2953    session: Session,
2954) -> Result<()> {
2955    let db = session.db().await;
2956    let channel_id = ChannelId::from_proto(request.channel_id);
2957    let invitee_id = UserId::from_proto(request.user_id);
2958    let InviteMemberResult {
2959        channel,
2960        notifications,
2961    } = db
2962        .invite_channel_member(
2963            channel_id,
2964            invitee_id,
2965            session.user_id(),
2966            request.role().into(),
2967        )
2968        .await?;
2969
2970    let update = proto::UpdateChannels {
2971        channel_invitations: vec![channel.to_proto()],
2972        ..Default::default()
2973    };
2974
2975    let connection_pool = session.connection_pool().await;
2976    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2977        session.peer.send(connection_id, update.clone())?;
2978    }
2979
2980    send_notifications(&connection_pool, &session.peer, notifications);
2981
2982    response.send(proto::Ack {})?;
2983    Ok(())
2984}
2985
2986/// remove someone from a channel
2987async fn remove_channel_member(
2988    request: proto::RemoveChannelMember,
2989    response: Response<proto::RemoveChannelMember>,
2990    session: Session,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let member_id = UserId::from_proto(request.user_id);
2995
2996    let RemoveChannelMemberResult {
2997        membership_update,
2998        notification_id,
2999    } = db
3000        .remove_channel_member(channel_id, member_id, session.user_id())
3001        .await?;
3002
3003    let mut connection_pool = session.connection_pool().await;
3004    notify_membership_updated(
3005        &mut connection_pool,
3006        membership_update,
3007        member_id,
3008        &session.peer,
3009    );
3010    for connection_id in connection_pool.user_connection_ids(member_id) {
3011        if let Some(notification_id) = notification_id {
3012            session
3013                .peer
3014                .send(
3015                    connection_id,
3016                    proto::DeleteNotification {
3017                        notification_id: notification_id.to_proto(),
3018                    },
3019                )
3020                .trace_err();
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Toggle the channel between public and private.
3029/// Care is taken to maintain the invariant that public channels only descend from public channels,
3030/// (though members-only channels can appear at any point in the hierarchy).
3031async fn set_channel_visibility(
3032    request: proto::SetChannelVisibility,
3033    response: Response<proto::SetChannelVisibility>,
3034    session: Session,
3035) -> Result<()> {
3036    let db = session.db().await;
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038    let visibility = request.visibility().into();
3039
3040    let channel_model = db
3041        .set_channel_visibility(channel_id, visibility, session.user_id())
3042        .await?;
3043    let root_id = channel_model.root_id();
3044    let channel = Channel::from_model(channel_model);
3045
3046    let mut connection_pool = session.connection_pool().await;
3047    for (user_id, role) in connection_pool
3048        .channel_user_ids(root_id)
3049        .collect::<Vec<_>>()
3050        .into_iter()
3051    {
3052        let update = if role.can_see_channel(channel.visibility) {
3053            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3054            proto::UpdateChannels {
3055                channels: vec![channel.to_proto()],
3056                ..Default::default()
3057            }
3058        } else {
3059            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3060            proto::UpdateChannels {
3061                delete_channels: vec![channel.id.to_proto()],
3062                ..Default::default()
3063            }
3064        };
3065
3066        for connection_id in connection_pool.user_connection_ids(user_id) {
3067            session.peer.send(connection_id, update.clone())?;
3068        }
3069    }
3070
3071    response.send(proto::Ack {})?;
3072    Ok(())
3073}
3074
3075/// Alter the role for a user in the channel.
3076async fn set_channel_member_role(
3077    request: proto::SetChannelMemberRole,
3078    response: Response<proto::SetChannelMemberRole>,
3079    session: Session,
3080) -> Result<()> {
3081    let db = session.db().await;
3082    let channel_id = ChannelId::from_proto(request.channel_id);
3083    let member_id = UserId::from_proto(request.user_id);
3084    let result = db
3085        .set_channel_member_role(
3086            channel_id,
3087            session.user_id(),
3088            member_id,
3089            request.role().into(),
3090        )
3091        .await?;
3092
3093    match result {
3094        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3095            let mut connection_pool = session.connection_pool().await;
3096            notify_membership_updated(
3097                &mut connection_pool,
3098                membership_update,
3099                member_id,
3100                &session.peer,
3101            )
3102        }
3103        db::SetMemberRoleResult::InviteUpdated(channel) => {
3104            let update = proto::UpdateChannels {
3105                channel_invitations: vec![channel.to_proto()],
3106                ..Default::default()
3107            };
3108
3109            for connection_id in session
3110                .connection_pool()
3111                .await
3112                .user_connection_ids(member_id)
3113            {
3114                session.peer.send(connection_id, update.clone())?;
3115            }
3116        }
3117    }
3118
3119    response.send(proto::Ack {})?;
3120    Ok(())
3121}
3122
3123/// Change the name of a channel
3124async fn rename_channel(
3125    request: proto::RenameChannel,
3126    response: Response<proto::RenameChannel>,
3127    session: Session,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let channel_model = db
3132        .rename_channel(channel_id, session.user_id(), &request.name)
3133        .await?;
3134    let root_id = channel_model.root_id();
3135    let channel = Channel::from_model(channel_model);
3136
3137    response.send(proto::RenameChannelResponse {
3138        channel: Some(channel.to_proto()),
3139    })?;
3140
3141    let connection_pool = session.connection_pool().await;
3142    let update = proto::UpdateChannels {
3143        channels: vec![channel.to_proto()],
3144        ..Default::default()
3145    };
3146    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147        if role.can_see_channel(channel.visibility) {
3148            session.peer.send(connection_id, update.clone())?;
3149        }
3150    }
3151
3152    Ok(())
3153}
3154
3155/// Move a channel to a new parent.
3156async fn move_channel(
3157    request: proto::MoveChannel,
3158    response: Response<proto::MoveChannel>,
3159    session: Session,
3160) -> Result<()> {
3161    let channel_id = ChannelId::from_proto(request.channel_id);
3162    let to = ChannelId::from_proto(request.to);
3163
3164    let (root_id, channels) = session
3165        .db()
3166        .await
3167        .move_channel(channel_id, to, session.user_id())
3168        .await?;
3169
3170    let connection_pool = session.connection_pool().await;
3171    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172        let channels = channels
3173            .iter()
3174            .filter_map(|channel| {
3175                if role.can_see_channel(channel.visibility) {
3176                    Some(channel.to_proto())
3177                } else {
3178                    None
3179                }
3180            })
3181            .collect::<Vec<_>>();
3182        if channels.is_empty() {
3183            continue;
3184        }
3185
3186        let update = proto::UpdateChannels {
3187            channels,
3188            ..Default::default()
3189        };
3190
3191        session.peer.send(connection_id, update.clone())?;
3192    }
3193
3194    response.send(Ack {})?;
3195    Ok(())
3196}
3197
3198/// Get the list of channel members
3199async fn get_channel_members(
3200    request: proto::GetChannelMembers,
3201    response: Response<proto::GetChannelMembers>,
3202    session: Session,
3203) -> Result<()> {
3204    let db = session.db().await;
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206    let limit = if request.limit == 0 {
3207        u16::MAX as u64
3208    } else {
3209        request.limit
3210    };
3211    let (members, users) = db
3212        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3213        .await?;
3214    response.send(proto::GetChannelMembersResponse { members, users })?;
3215    Ok(())
3216}
3217
3218/// Accept or decline a channel invitation.
3219async fn respond_to_channel_invite(
3220    request: proto::RespondToChannelInvite,
3221    response: Response<proto::RespondToChannelInvite>,
3222    session: Session,
3223) -> Result<()> {
3224    let db = session.db().await;
3225    let channel_id = ChannelId::from_proto(request.channel_id);
3226    let RespondToChannelInvite {
3227        membership_update,
3228        notifications,
3229    } = db
3230        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3231        .await?;
3232
3233    let mut connection_pool = session.connection_pool().await;
3234    if let Some(membership_update) = membership_update {
3235        notify_membership_updated(
3236            &mut connection_pool,
3237            membership_update,
3238            session.user_id(),
3239            &session.peer,
3240        );
3241    } else {
3242        let update = proto::UpdateChannels {
3243            remove_channel_invitations: vec![channel_id.to_proto()],
3244            ..Default::default()
3245        };
3246
3247        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3248            session.peer.send(connection_id, update.clone())?;
3249        }
3250    };
3251
3252    send_notifications(&connection_pool, &session.peer, notifications);
3253
3254    response.send(proto::Ack {})?;
3255
3256    Ok(())
3257}
3258
3259/// Join the channels' room
3260async fn join_channel(
3261    request: proto::JoinChannel,
3262    response: Response<proto::JoinChannel>,
3263    session: Session,
3264) -> Result<()> {
3265    let channel_id = ChannelId::from_proto(request.channel_id);
3266    join_channel_internal(channel_id, Box::new(response), session).await
3267}
3268
/// Abstracts over the response handles that `join_channel_internal` may reply through.
///
/// Both `JoinChannel` and `JoinRoom` responses are answered with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `Response::<_>::send` through a fully-qualified path. Presumably this is to
// make explicit that the inherent `Response::send` is meant, not this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3282
/// Join the room that belongs to `channel_id` on behalf of `session`.
///
/// `response` may be either a `JoinChannel` or a `JoinRoom` response handle (both
/// implement [`JoinChannelInternalResponse`]). In order, this:
/// 1. evicts a stale room connection for this user, if the database reports one;
/// 2. joins the channel's room in the database;
/// 3. replies with the room, the channel id, and — when a LiveKit client is
///    configured and a token could be minted — LiveKit connection info;
/// 4. applies any membership change to the connection pool and notifies the
///    user's connections (`notify_membership_updated`);
/// 5. calls `room_updated` and `channel_updated` to propagate the new state, then
///    `update_user_contacts`.
///
/// # Errors
///
/// Propagates failures from the database, from leaving the stale room, from
/// sending the response, and from `update_user_contacts`; fails with
/// "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the db and connection-pool guards taken inside are released
    // before `channel_updated` reacquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room (presumably
            // `leave_room_for_session` acquires it itself — TODO confirm), then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a
        // regular room token. If minting the token fails, the error goes through
        // `trace_err` and the reply simply carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining can change the user's channel membership; keep the pool's
        // subscriptions and the user's clients in sync when it does.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3376
3377/// Start editing the channel notes
3378async fn join_channel_buffer(
3379    request: proto::JoinChannelBuffer,
3380    response: Response<proto::JoinChannelBuffer>,
3381    session: Session,
3382) -> Result<()> {
3383    let db = session.db().await;
3384    let channel_id = ChannelId::from_proto(request.channel_id);
3385
3386    let open_response = db
3387        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3388        .await?;
3389
3390    let collaborators = open_response.collaborators.clone();
3391    response.send(open_response)?;
3392
3393    let update = UpdateChannelBufferCollaborators {
3394        channel_id: channel_id.to_proto(),
3395        collaborators: collaborators.clone(),
3396    };
3397    channel_buffer_updated(
3398        session.connection_id,
3399        collaborators
3400            .iter()
3401            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3402        &update,
3403        &session.peer,
3404    );
3405
3406    Ok(())
3407}
3408
3409/// Edit the channel notes
3410async fn update_channel_buffer(
3411    request: proto::UpdateChannelBuffer,
3412    session: Session,
3413) -> Result<()> {
3414    let db = session.db().await;
3415    let channel_id = ChannelId::from_proto(request.channel_id);
3416
3417    let (collaborators, epoch, version) = db
3418        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3419        .await?;
3420
3421    channel_buffer_updated(
3422        session.connection_id,
3423        collaborators.clone(),
3424        &proto::UpdateChannelBuffer {
3425            channel_id: channel_id.to_proto(),
3426            operations: request.operations,
3427        },
3428        &session.peer,
3429    );
3430
3431    let pool = &*session.connection_pool().await;
3432
3433    let non_collaborators =
3434        pool.channel_connection_ids(channel_id)
3435            .filter_map(|(connection_id, _)| {
3436                if collaborators.contains(&connection_id) {
3437                    None
3438                } else {
3439                    Some(connection_id)
3440                }
3441            });
3442
3443    broadcast(None, non_collaborators, |peer_id| {
3444        session.peer.send(
3445            peer_id,
3446            proto::UpdateChannels {
3447                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3448                    channel_id: channel_id.to_proto(),
3449                    epoch: epoch as u64,
3450                    version: version.clone(),
3451                }],
3452                ..Default::default()
3453            },
3454        )
3455    });
3456
3457    Ok(())
3458}
3459
3460/// Rejoin the channel notes after a connection blip
3461async fn rejoin_channel_buffers(
3462    request: proto::RejoinChannelBuffers,
3463    response: Response<proto::RejoinChannelBuffers>,
3464    session: Session,
3465) -> Result<()> {
3466    let db = session.db().await;
3467    let buffers = db
3468        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3469        .await?;
3470
3471    for rejoined_buffer in &buffers {
3472        let collaborators_to_notify = rejoined_buffer
3473            .buffer
3474            .collaborators
3475            .iter()
3476            .filter_map(|c| Some(c.peer_id?.into()));
3477        channel_buffer_updated(
3478            session.connection_id,
3479            collaborators_to_notify,
3480            &proto::UpdateChannelBufferCollaborators {
3481                channel_id: rejoined_buffer.buffer.channel_id,
3482                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3483            },
3484            &session.peer,
3485        );
3486    }
3487
3488    response.send(proto::RejoinChannelBuffersResponse {
3489        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3490    })?;
3491
3492    Ok(())
3493}
3494
3495/// Stop editing the channel notes
3496async fn leave_channel_buffer(
3497    request: proto::LeaveChannelBuffer,
3498    response: Response<proto::LeaveChannelBuffer>,
3499    session: Session,
3500) -> Result<()> {
3501    let db = session.db().await;
3502    let channel_id = ChannelId::from_proto(request.channel_id);
3503
3504    let left_buffer = db
3505        .leave_channel_buffer(channel_id, session.connection_id)
3506        .await?;
3507
3508    response.send(Ack {})?;
3509
3510    channel_buffer_updated(
3511        session.connection_id,
3512        left_buffer.connections,
3513        &proto::UpdateChannelBufferCollaborators {
3514            channel_id: channel_id.to_proto(),
3515            collaborators: left_buffer.collaborators,
3516        },
3517        &session.peer,
3518    );
3519
3520    Ok(())
3521}
3522
3523fn channel_buffer_updated<T: EnvelopedMessage>(
3524    sender_id: ConnectionId,
3525    collaborators: impl IntoIterator<Item = ConnectionId>,
3526    message: &T,
3527    peer: &Peer,
3528) {
3529    broadcast(Some(sender_id), collaborators, |peer_id| {
3530        peer.send(peer_id, message.clone())
3531    });
3532}
3533
3534fn send_notifications(
3535    connection_pool: &ConnectionPool,
3536    peer: &Peer,
3537    notifications: db::NotificationBatch,
3538) {
3539    for (user_id, notification) in notifications {
3540        for connection_id in connection_pool.user_connection_ids(user_id) {
3541            if let Err(error) = peer.send(
3542                connection_id,
3543                proto::AddNotification {
3544                    notification: Some(notification.clone()),
3545                },
3546            ) {
3547                tracing::error!(
3548                    "failed to send notification to {:?} {}",
3549                    connection_id,
3550                    error
3551                );
3552            }
3553        }
3554    }
3555}
3556
3557/// Send a message to the channel
3558async fn send_channel_message(
3559    request: proto::SendChannelMessage,
3560    response: Response<proto::SendChannelMessage>,
3561    session: Session,
3562) -> Result<()> {
3563    // Validate the message body.
3564    let body = request.body.trim().to_string();
3565    if body.len() > MAX_MESSAGE_LEN {
3566        return Err(anyhow!("message is too long"))?;
3567    }
3568    if body.is_empty() {
3569        return Err(anyhow!("message can't be blank"))?;
3570    }
3571
3572    // TODO: adjust mentions if body is trimmed
3573
3574    let timestamp = OffsetDateTime::now_utc();
3575    let nonce = request.nonce.context("nonce can't be blank")?;
3576
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578    let CreatedChannelMessage {
3579        message_id,
3580        participant_connection_ids,
3581        notifications,
3582    } = session
3583        .db()
3584        .await
3585        .create_channel_message(
3586            channel_id,
3587            session.user_id(),
3588            &body,
3589            &request.mentions,
3590            timestamp,
3591            nonce.clone().into(),
3592            request.reply_to_message_id.map(MessageId::from_proto),
3593        )
3594        .await?;
3595
3596    let message = proto::ChannelMessage {
3597        sender_id: session.user_id().to_proto(),
3598        id: message_id.to_proto(),
3599        body,
3600        mentions: request.mentions,
3601        timestamp: timestamp.unix_timestamp() as u64,
3602        nonce: Some(nonce),
3603        reply_to_message_id: request.reply_to_message_id,
3604        edited_at: None,
3605    };
3606    broadcast(
3607        Some(session.connection_id),
3608        participant_connection_ids.clone(),
3609        |connection| {
3610            session.peer.send(
3611                connection,
3612                proto::ChannelMessageSent {
3613                    channel_id: channel_id.to_proto(),
3614                    message: Some(message.clone()),
3615                },
3616            )
3617        },
3618    );
3619    response.send(proto::SendChannelMessageResponse {
3620        message: Some(message),
3621    })?;
3622
3623    let pool = &*session.connection_pool().await;
3624    let non_participants =
3625        pool.channel_connection_ids(channel_id)
3626            .filter_map(|(connection_id, _)| {
3627                if participant_connection_ids.contains(&connection_id) {
3628                    None
3629                } else {
3630                    Some(connection_id)
3631                }
3632            });
3633    broadcast(None, non_participants, |peer_id| {
3634        session.peer.send(
3635            peer_id,
3636            proto::UpdateChannels {
3637                latest_channel_message_ids: vec![proto::ChannelMessageId {
3638                    channel_id: channel_id.to_proto(),
3639                    message_id: message_id.to_proto(),
3640                }],
3641                ..Default::default()
3642            },
3643        )
3644    });
3645    send_notifications(pool, &session.peer, notifications);
3646
3647    Ok(())
3648}
3649
3650/// Delete a channel message
3651async fn remove_channel_message(
3652    request: proto::RemoveChannelMessage,
3653    response: Response<proto::RemoveChannelMessage>,
3654    session: Session,
3655) -> Result<()> {
3656    let channel_id = ChannelId::from_proto(request.channel_id);
3657    let message_id = MessageId::from_proto(request.message_id);
3658    let (connection_ids, existing_notification_ids) = session
3659        .db()
3660        .await
3661        .remove_channel_message(channel_id, message_id, session.user_id())
3662        .await?;
3663
3664    broadcast(
3665        Some(session.connection_id),
3666        connection_ids,
3667        move |connection| {
3668            session.peer.send(connection, request.clone())?;
3669
3670            for notification_id in &existing_notification_ids {
3671                session.peer.send(
3672                    connection,
3673                    proto::DeleteNotification {
3674                        notification_id: (*notification_id).to_proto(),
3675                    },
3676                )?;
3677            }
3678
3679            Ok(())
3680        },
3681    );
3682    response.send(proto::Ack {})?;
3683    Ok(())
3684}
3685
3686async fn update_channel_message(
3687    request: proto::UpdateChannelMessage,
3688    response: Response<proto::UpdateChannelMessage>,
3689    session: Session,
3690) -> Result<()> {
3691    let channel_id = ChannelId::from_proto(request.channel_id);
3692    let message_id = MessageId::from_proto(request.message_id);
3693    let updated_at = OffsetDateTime::now_utc();
3694    let UpdatedChannelMessage {
3695        message_id,
3696        participant_connection_ids,
3697        notifications,
3698        reply_to_message_id,
3699        timestamp,
3700        deleted_mention_notification_ids,
3701        updated_mention_notifications,
3702    } = session
3703        .db()
3704        .await
3705        .update_channel_message(
3706            channel_id,
3707            message_id,
3708            session.user_id(),
3709            request.body.as_str(),
3710            &request.mentions,
3711            updated_at,
3712        )
3713        .await?;
3714
3715    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3716
3717    let message = proto::ChannelMessage {
3718        sender_id: session.user_id().to_proto(),
3719        id: message_id.to_proto(),
3720        body: request.body.clone(),
3721        mentions: request.mentions.clone(),
3722        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3723        nonce: Some(nonce),
3724        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3725        edited_at: Some(updated_at.unix_timestamp() as u64),
3726    };
3727
3728    response.send(proto::Ack {})?;
3729
3730    let pool = &*session.connection_pool().await;
3731    broadcast(
3732        Some(session.connection_id),
3733        participant_connection_ids,
3734        |connection| {
3735            session.peer.send(
3736                connection,
3737                proto::ChannelMessageUpdate {
3738                    channel_id: channel_id.to_proto(),
3739                    message: Some(message.clone()),
3740                },
3741            )?;
3742
3743            for notification_id in &deleted_mention_notification_ids {
3744                session.peer.send(
3745                    connection,
3746                    proto::DeleteNotification {
3747                        notification_id: (*notification_id).to_proto(),
3748                    },
3749                )?;
3750            }
3751
3752            for notification in &updated_mention_notifications {
3753                session.peer.send(
3754                    connection,
3755                    proto::UpdateNotification {
3756                        notification: Some(notification.clone()),
3757                    },
3758                )?;
3759            }
3760
3761            Ok(())
3762        },
3763    );
3764
3765    send_notifications(pool, &session.peer, notifications);
3766
3767    Ok(())
3768}
3769
3770/// Mark a channel message as read
3771async fn acknowledge_channel_message(
3772    request: proto::AckChannelMessage,
3773    session: Session,
3774) -> Result<()> {
3775    let channel_id = ChannelId::from_proto(request.channel_id);
3776    let message_id = MessageId::from_proto(request.message_id);
3777    let notifications = session
3778        .db()
3779        .await
3780        .observe_channel_message(channel_id, session.user_id(), message_id)
3781        .await?;
3782    send_notifications(
3783        &*session.connection_pool().await,
3784        &session.peer,
3785        notifications,
3786    );
3787    Ok(())
3788}
3789
3790/// Mark a buffer version as synced
3791async fn acknowledge_buffer_version(
3792    request: proto::AckBufferOperation,
3793    session: Session,
3794) -> Result<()> {
3795    let buffer_id = BufferId::from_proto(request.buffer_id);
3796    session
3797        .db()
3798        .await
3799        .observe_buffer_version(
3800            buffer_id,
3801            session.user_id(),
3802            request.epoch as i32,
3803            &request.version,
3804        )
3805        .await?;
3806    Ok(())
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811    _request: proto::GetSupermavenApiKey,
3812    response: Response<proto::GetSupermavenApiKey>,
3813    session: Session,
3814) -> Result<()> {
3815    let user_id: String = session.user_id().to_string();
3816    if !session.is_staff() {
3817        return Err(anyhow!("supermaven not enabled for this account"))?;
3818    }
3819
3820    let email = session.email().context("user must have an email")?;
3821
3822    let supermaven_admin_api = session
3823        .supermaven_client
3824        .as_ref()
3825        .context("supermaven not configured")?;
3826
3827    let result = supermaven_admin_api
3828        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3829        .await?;
3830
3831    response.send(proto::GetSupermavenApiKeyResponse {
3832        api_key: result.api_key,
3833    })?;
3834
3835    Ok(())
3836}
3837
3838/// Start receiving chat updates for a channel
3839async fn join_channel_chat(
3840    request: proto::JoinChannelChat,
3841    response: Response<proto::JoinChannelChat>,
3842    session: Session,
3843) -> Result<()> {
3844    let channel_id = ChannelId::from_proto(request.channel_id);
3845
3846    let db = session.db().await;
3847    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3848        .await?;
3849    let messages = db
3850        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3851        .await?;
3852    response.send(proto::JoinChannelChatResponse {
3853        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3854        messages,
3855    })?;
3856    Ok(())
3857}
3858
3859/// Stop receiving chat updates for a channel
3860async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3861    let channel_id = ChannelId::from_proto(request.channel_id);
3862    session
3863        .db()
3864        .await
3865        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3866        .await?;
3867    Ok(())
3868}
3869
3870/// Retrieve the chat history for a channel
3871async fn get_channel_messages(
3872    request: proto::GetChannelMessages,
3873    response: Response<proto::GetChannelMessages>,
3874    session: Session,
3875) -> Result<()> {
3876    let channel_id = ChannelId::from_proto(request.channel_id);
3877    let messages = session
3878        .db()
3879        .await
3880        .get_channel_messages(
3881            channel_id,
3882            session.user_id(),
3883            MESSAGE_COUNT_PER_PAGE,
3884            Some(MessageId::from_proto(request.before_message_id)),
3885        )
3886        .await?;
3887    response.send(proto::GetChannelMessagesResponse {
3888        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3889        messages,
3890    })?;
3891    Ok(())
3892}
3893
3894/// Retrieve specific chat messages
3895async fn get_channel_messages_by_id(
3896    request: proto::GetChannelMessagesById,
3897    response: Response<proto::GetChannelMessagesById>,
3898    session: Session,
3899) -> Result<()> {
3900    let message_ids = request
3901        .message_ids
3902        .iter()
3903        .map(|id| MessageId::from_proto(*id))
3904        .collect::<Vec<_>>();
3905    let messages = session
3906        .db()
3907        .await
3908        .get_channel_messages_by_id(session.user_id(), &message_ids)
3909        .await?;
3910    response.send(proto::GetChannelMessagesResponse {
3911        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3912        messages,
3913    })?;
3914    Ok(())
3915}
3916
3917/// Retrieve the current users notifications
3918async fn get_notifications(
3919    request: proto::GetNotifications,
3920    response: Response<proto::GetNotifications>,
3921    session: Session,
3922) -> Result<()> {
3923    let notifications = session
3924        .db()
3925        .await
3926        .get_notifications(
3927            session.user_id(),
3928            NOTIFICATION_COUNT_PER_PAGE,
3929            request.before_id.map(db::NotificationId::from_proto),
3930        )
3931        .await?;
3932    response.send(proto::GetNotificationsResponse {
3933        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3934        notifications,
3935    })?;
3936    Ok(())
3937}
3938
3939/// Mark notifications as read
3940async fn mark_notification_as_read(
3941    request: proto::MarkNotificationRead,
3942    response: Response<proto::MarkNotificationRead>,
3943    session: Session,
3944) -> Result<()> {
3945    let database = &session.db().await;
3946    let notifications = database
3947        .mark_notification_as_read_by_id(
3948            session.user_id(),
3949            NotificationId::from_proto(request.notification_id),
3950        )
3951        .await?;
3952    send_notifications(
3953        &*session.connection_pool().await,
3954        &session.peer,
3955        notifications,
3956    );
3957    response.send(proto::Ack {})?;
3958    Ok(())
3959}
3960
3961/// Get the current users information
3962async fn get_private_user_info(
3963    _request: proto::GetPrivateUserInfo,
3964    response: Response<proto::GetPrivateUserInfo>,
3965    session: Session,
3966) -> Result<()> {
3967    let db = session.db().await;
3968
3969    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3970    let user = db
3971        .get_user_by_id(session.user_id())
3972        .await?
3973        .context("user not found")?;
3974    let flags = db.get_user_flags(session.user_id()).await?;
3975
3976    response.send(proto::GetPrivateUserInfoResponse {
3977        metrics_id,
3978        staff: user.admin,
3979        flags,
3980        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3981    })?;
3982    Ok(())
3983}
3984
3985/// Accept the terms of service (tos) on behalf of the current user
3986async fn accept_terms_of_service(
3987    _request: proto::AcceptTermsOfService,
3988    response: Response<proto::AcceptTermsOfService>,
3989    session: Session,
3990) -> Result<()> {
3991    let db = session.db().await;
3992
3993    let accepted_tos_at = Utc::now();
3994    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3995        .await?;
3996
3997    response.send(proto::AcceptTermsOfServiceResponse {
3998        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3999    })?;
4000    Ok(())
4001}
4002
/// The minimum account age an account must have in order to use the LLM service
/// (30 days).
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4005
/// Issue an LLM API token for the current user.
///
/// The user must have accepted the terms of service. Before minting the token,
/// this makes sure the user has:
/// - a billing customer record: if none exists, a Stripe customer is found or
///   created by email and a billing customer is found or created for it;
/// - an active billing subscription: if none exists, the customer is subscribed
///   to Zed Free in Stripe and the subscription is recorded in the database.
///
/// The token is built from the user, staff status, billing customer, preferences
/// and subscription, feature flags, the session's system id, and the server
/// config. It is sent back in a `GetLlmTokenResponse`.
///
/// # Errors
///
/// Fails if the user cannot be found or has not accepted the terms of service,
/// if no Stripe client or Stripe billing object is configured, if any Stripe or
/// database call fails, or if token creation fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this db guard is held across the Stripe calls below — confirm
    // that is intended.
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the existing billing customer; otherwise resolve (or create) the
    // Stripe customer by email and the matching billing customer record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription; otherwise subscribe the customer to Zed Free
    // in Stripe and record that subscription in the database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4089
4090fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4091    let message = match message {
4092        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4093        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4094        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4095        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4096        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4097            code: frame.code.into(),
4098            reason: frame.reason.as_str().to_owned().into(),
4099        })),
4100        // We should never receive a frame while reading the message, according
4101        // to the `tungstenite` maintainers:
4102        //
4103        // > It cannot occur when you read messages from the WebSocket, but it
4104        // > can be used when you want to send the raw frames (e.g. you want to
4105        // > send the frames to the WebSocket without composing the full message first).
4106        // >
4107        // > — https://github.com/snapview/tungstenite-rs/issues/268
4108        TungsteniteMessage::Frame(_) => {
4109            bail!("received an unexpected frame while reading the message")
4110        }
4111    };
4112
4113    Ok(message)
4114}
4115
4116fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4117    match message {
4118        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4119        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4120        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4121        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4122        AxumMessage::Close(frame) => {
4123            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4124                code: frame.code.into(),
4125                reason: frame.reason.as_ref().into(),
4126            }))
4127        }
4128    }
4129}
4130
4131fn notify_membership_updated(
4132    connection_pool: &mut ConnectionPool,
4133    result: MembershipUpdated,
4134    user_id: UserId,
4135    peer: &Peer,
4136) {
4137    for membership in &result.new_channels.channel_memberships {
4138        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4139    }
4140    for channel_id in &result.removed_channels {
4141        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4142    }
4143
4144    let user_channels_update = proto::UpdateUserChannels {
4145        channel_memberships: result
4146            .new_channels
4147            .channel_memberships
4148            .iter()
4149            .map(|cm| proto::ChannelMembership {
4150                channel_id: cm.channel_id.to_proto(),
4151                role: cm.role.into(),
4152            })
4153            .collect(),
4154        ..Default::default()
4155    };
4156
4157    let mut update = build_channels_update(result.new_channels);
4158    update.delete_channels = result
4159        .removed_channels
4160        .into_iter()
4161        .map(|id| id.to_proto())
4162        .collect();
4163    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4164
4165    for connection_id in connection_pool.user_connection_ids(user_id) {
4166        peer.send(connection_id, user_channels_update.clone())
4167            .trace_err();
4168        peer.send(connection_id, update.clone()).trace_err();
4169    }
4170}
4171
4172fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4173    proto::UpdateUserChannels {
4174        channel_memberships: channels
4175            .channel_memberships
4176            .iter()
4177            .map(|m| proto::ChannelMembership {
4178                channel_id: m.channel_id.to_proto(),
4179                role: m.role.into(),
4180            })
4181            .collect(),
4182        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4183        observed_channel_message_id: channels.observed_channel_messages.clone(),
4184    }
4185}
4186
4187fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4188    let mut update = proto::UpdateChannels::default();
4189
4190    for channel in channels.channels {
4191        update.channels.push(channel.to_proto());
4192    }
4193
4194    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4195    update.latest_channel_message_ids = channels.latest_channel_messages;
4196
4197    for (channel_id, participants) in channels.channel_participants {
4198        update
4199            .channel_participants
4200            .push(proto::ChannelParticipants {
4201                channel_id: channel_id.to_proto(),
4202                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4203            });
4204    }
4205
4206    for channel in channels.invited_channels {
4207        update.channel_invitations.push(channel.to_proto());
4208    }
4209
4210    update
4211}
4212
4213fn build_initial_contacts_update(
4214    contacts: Vec<db::Contact>,
4215    pool: &ConnectionPool,
4216) -> proto::UpdateContacts {
4217    let mut update = proto::UpdateContacts::default();
4218
4219    for contact in contacts {
4220        match contact {
4221            db::Contact::Accepted { user_id, busy } => {
4222                update.contacts.push(contact_for_user(user_id, busy, pool));
4223            }
4224            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4225            db::Contact::Incoming { user_id } => {
4226                update
4227                    .incoming_requests
4228                    .push(proto::IncomingContactRequest {
4229                        requester_id: user_id.to_proto(),
4230                    })
4231            }
4232        }
4233    }
4234
4235    update
4236}
4237
4238fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4239    proto::Contact {
4240        user_id: user_id.to_proto(),
4241        online: pool.is_user_online(user_id),
4242        busy,
4243    }
4244}
4245
4246fn room_updated(room: &proto::Room, peer: &Peer) {
4247    broadcast(
4248        None,
4249        room.participants
4250            .iter()
4251            .filter_map(|participant| Some(participant.peer_id?.into())),
4252        |peer_id| {
4253            peer.send(
4254                peer_id,
4255                proto::RoomUpdated {
4256                    room: Some(room.clone()),
4257                },
4258            )
4259        },
4260    );
4261}
4262
4263fn channel_updated(
4264    channel: &db::channel::Model,
4265    room: &proto::Room,
4266    peer: &Peer,
4267    pool: &ConnectionPool,
4268) {
4269    let participants = room
4270        .participants
4271        .iter()
4272        .map(|p| p.user_id)
4273        .collect::<Vec<_>>();
4274
4275    broadcast(
4276        None,
4277        pool.channel_connection_ids(channel.root_id())
4278            .filter_map(|(channel_id, role)| {
4279                role.can_see_channel(channel.visibility)
4280                    .then_some(channel_id)
4281            }),
4282        |peer_id| {
4283            peer.send(
4284                peer_id,
4285                proto::UpdateChannels {
4286                    channel_participants: vec![proto::ChannelParticipants {
4287                        channel_id: channel.id.to_proto(),
4288                        participant_user_ids: participants.clone(),
4289                    }],
4290                    ..Default::default()
4291                },
4292            )
4293        },
4294    );
4295}
4296
4297async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4298    let db = session.db().await;
4299
4300    let contacts = db.get_contacts(user_id).await?;
4301    let busy = db.is_user_busy(user_id).await?;
4302
4303    let pool = session.connection_pool().await;
4304    let updated_contact = contact_for_user(user_id, busy, &pool);
4305    for contact in contacts {
4306        if let db::Contact::Accepted {
4307            user_id: contact_user_id,
4308            ..
4309        } = contact
4310        {
4311            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4312                session
4313                    .peer
4314                    .send(
4315                        contact_conn_id,
4316                        proto::UpdateContacts {
4317                            contacts: vec![updated_contact.clone()],
4318                            remove_contacts: Default::default(),
4319                            incoming_requests: Default::default(),
4320                            remove_incoming_requests: Default::default(),
4321                            outgoing_requests: Default::default(),
4322                            remove_outgoing_requests: Default::default(),
4323                        },
4324                    )
4325                    .trace_err();
4326            }
4327        }
4328    }
4329    Ok(())
4330}
4331
/// Removes `connection_id` from the room it is in and propagates the effects.
///
/// When the database reports that the connection left a room, this:
/// - notifies the connections of every project the session left (see
///   `project_left`);
/// - broadcasts the updated room to its participants and, if the room is
///   attached to a channel, the updated participant list to that channel;
/// - sends `CallCanceled` to every connection of each user whose call the
///   database reports as canceled;
/// - refreshes the contact entries of this user and of those canceled users;
/// - removes the user from the LiveKit room, and deletes that LiveKit room
///   when `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately if the connection was not in a room.
///
/// # Errors
///
/// Returns an error if the database `leave_room` call fails or if
/// `update_user_contacts` fails for any affected user. LiveKit and message
/// send failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below and used after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out on the next lines, so the `room`
        // broadcast below carries a default `livekit_room` value.
        // NOTE(review): presumably intentional (clients shouldn't need it) — confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool guard is released before
    // `update_user_contacts` below, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Shadows the `connection_id` parameter; these are the canceled
            // user's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?`, which skips the
    // LiveKit participant removal / room deletion below — confirm acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4405
4406async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4407    let left_channel_buffers = session
4408        .db()
4409        .await
4410        .leave_channel_buffers(session.connection_id)
4411        .await?;
4412
4413    for left_buffer in left_channel_buffers {
4414        channel_buffer_updated(
4415            session.connection_id,
4416            left_buffer.connections,
4417            &proto::UpdateChannelBufferCollaborators {
4418                channel_id: left_buffer.channel_id.to_proto(),
4419                collaborators: left_buffer.collaborators,
4420            },
4421            &session.peer,
4422        );
4423    }
4424
4425    Ok(())
4426}
4427
4428fn project_left(project: &db::LeftProject, session: &Session) {
4429    for connection_id in &project.connection_ids {
4430        if project.should_unshare {
4431            session
4432                .peer
4433                .send(
4434                    *connection_id,
4435                    proto::UnshareProject {
4436                        project_id: project.id.to_proto(),
4437                    },
4438                )
4439                .trace_err();
4440        } else {
4441            session
4442                .peer
4443                .send(
4444                    *connection_id,
4445                    proto::RemoveProjectCollaborator {
4446                        project_id: project.id.to_proto(),
4447                        peer_id: Some(session.connection_id.into()),
4448                    },
4449                )
4450                .trace_err();
4451        }
4452    }
4453}
4454
4455pub trait ResultExt {
4456    type Ok;
4457
4458    fn trace_err(self) -> Option<Self::Ok>;
4459}
4460
4461impl<T, E> ResultExt for Result<T, E>
4462where
4463    E: std::fmt::Debug,
4464{
4465    type Ok = T;
4466
4467    #[track_caller]
4468    fn trace_err(self) -> Option<T> {
4469        match self {
4470            Ok(value) => Some(value),
4471            Err(error) => {
4472                tracing::error!("{:?}", error);
4473                None
4474            }
4475        }
4476    }
4477}