rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   8use crate::stripe_client::StripeCustomerId;
   9use crate::{
  10    AppState, Error, Result, auth,
  11    db::{
  12        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  13        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  14        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  15        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  16    },
  17    executor::Executor,
  18};
  19use anyhow::{Context as _, anyhow, bail};
  20use async_tungstenite::tungstenite::{
  21    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  22};
  23use axum::{
  24    Extension, Router, TypedHeader,
  25    body::Body,
  26    extract::{
  27        ConnectInfo, WebSocketUpgrade,
  28        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  29    },
  30    headers::{Header, HeaderName},
  31    http::StatusCode,
  32    middleware,
  33    response::IntoResponse,
  34    routing::get,
  35};
  36use chrono::Utc;
  37use collections::{HashMap, HashSet};
  38pub use connection_pool::{ConnectionPool, ZedVersion};
  39use core::fmt::{self, Debug, Formatter};
  40use reqwest_client::ReqwestClient;
  41use rpc::proto::split_repository_update;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  46    stream::FuturesUnordered,
  47};
  48use prometheus::{IntGauge, register_int_gauge};
  49use rpc::{
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        Arc, OnceLock,
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{Semaphore, watch};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    Instrument,
  77    field::{self},
  78    info_span, instrument,
  79};
  80
/// Reconnection grace period.
///
/// NOTE(review): not used in this part of the file — presumably how long a dropped client's
/// state is retained before it is released; confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long a freshly started server waits (in `Server::start`) before cleaning up resources
/// left behind by previous servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting longer than that gives
/// the old pods time to exit before their resources are reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file; their names
// suggest channel-message paging, maximum chat message length, and notification paging —
// confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler, stored in `Server::handlers` keyed by the `TypeId` of the
/// message it handles. It receives the boxed envelope plus the sender's session and returns a
/// boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111}
 112
 113impl Principal {
 114    fn user(&self) -> &User {
 115        match self {
 116            Principal::User(user) => user,
 117            Principal::Impersonated { user, .. } => user,
 118        }
 119    }
 120
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132        }
 133    }
 134}
 135
 136#[derive(Clone)]
 137struct Session {
 138    principal: Principal,
 139    connection_id: ConnectionId,
 140    db: Arc<tokio::sync::Mutex<DbHandle>>,
 141    peer: Arc<Peer>,
 142    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 143    app_state: Arc<AppState>,
 144    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 145    /// The GeoIP country code for the user.
 146    #[allow(unused)]
 147    geoip_country_code: Option<String>,
 148    system_id: Option<String>,
 149    _executor: Executor,
 150}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn is_staff(&self) -> bool {
 173        match &self.principal {
 174            Principal::User(user) => user.admin,
 175            Principal::Impersonated { .. } => true,
 176        }
 177    }
 178
 179    fn user_id(&self) -> UserId {
 180        match &self.principal {
 181            Principal::User(user) => user.id,
 182            Principal::Impersonated { user, .. } => user.id,
 183        }
 184    }
 185
 186    pub fn email(&self) -> Option<String> {
 187        match &self.principal {
 188            Principal::User(user) => user.email_address.clone(),
 189            Principal::Impersonated { user, .. } => user.email_address.clone(),
 190        }
 191    }
 192}
 193
 194impl Debug for Session {
 195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 196        let mut result = f.debug_struct("Session");
 197        match &self.principal {
 198            Principal::User(user) => {
 199                result.field("user", &user.github_login);
 200            }
 201            Principal::Impersonated { user, admin } => {
 202                result.field("user", &user.github_login);
 203                result.field("impersonator", &admin.github_login);
 204            }
 205        }
 206        result.field("connection_id", &self.connection_id).finish()
 207    }
 208}
 209
 210struct DbHandle(Arc<Database>);
 211
 212impl Deref for DbHandle {
 213    type Target = Database;
 214
 215    fn deref(&self) -> &Self::Target {
 216        self.0.as_ref()
 217    }
 218}
 219
/// The collaboration RPC server.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send messages and responses to connected clients.
    peer: Arc<Peer>,
    /// Connection pool, locked briefly whenever connections need to be looked up.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message they handle (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`).
    teardown: watch::Sender<bool>,
}

/// Lock guard over the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 233
 234#[derive(Serialize)]
 235pub struct ServerSnapshot<'a> {
 236    peer: &'a Peer,
 237    #[serde(serialize_with = "serialize_deref")]
 238    connection_pool: ConnectionPoolGuard<'a>,
 239}
 240
 241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 242where
 243    S: Serializer,
 244    T: Deref<Target = U>,
 245    U: Serialize,
 246{
 247    Serialize::serialize(value.deref(), serializer)
 248}
 249
 250impl Server {
 251    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 252        let mut server = Self {
 253            id: parking_lot::Mutex::new(id),
 254            peer: Peer::new(id.0 as u32),
 255            app_state: app_state.clone(),
 256            connection_pool: Default::default(),
 257            handlers: Default::default(),
 258            teardown: watch::channel(false).0,
 259        };
 260
 261        server
 262            .add_request_handler(ping)
 263            .add_request_handler(create_room)
 264            .add_request_handler(join_room)
 265            .add_request_handler(rejoin_room)
 266            .add_request_handler(leave_room)
 267            .add_request_handler(set_room_participant_role)
 268            .add_request_handler(call)
 269            .add_request_handler(cancel_call)
 270            .add_message_handler(decline_call)
 271            .add_request_handler(update_participant_location)
 272            .add_request_handler(share_project)
 273            .add_message_handler(unshare_project)
 274            .add_request_handler(join_project)
 275            .add_message_handler(leave_project)
 276            .add_request_handler(update_project)
 277            .add_request_handler(update_worktree)
 278            .add_request_handler(update_repository)
 279            .add_request_handler(remove_repository)
 280            .add_message_handler(start_language_server)
 281            .add_message_handler(update_language_server)
 282            .add_message_handler(update_diagnostic_summary)
 283            .add_message_handler(update_worktree_settings)
 284            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 285            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 286            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 287            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 288            .add_request_handler(forward_find_search_candidates_request)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 291            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 293            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 294            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 295            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 296            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 297            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 298            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 300            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 301            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 303            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 304            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 305            .add_request_handler(
 306                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 307            )
 308            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 309            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 310            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 311            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 312            .add_request_handler(
 313                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 314            )
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 349            .add_message_handler(create_buffer_for_peer)
 350            .add_request_handler(update_buffer)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 357            .add_request_handler(get_users)
 358            .add_request_handler(fuzzy_search_users)
 359            .add_request_handler(request_contact)
 360            .add_request_handler(remove_contact)
 361            .add_request_handler(respond_to_contact_request)
 362            .add_message_handler(subscribe_to_channels)
 363            .add_request_handler(create_channel)
 364            .add_request_handler(delete_channel)
 365            .add_request_handler(invite_channel_member)
 366            .add_request_handler(remove_channel_member)
 367            .add_request_handler(set_channel_member_role)
 368            .add_request_handler(set_channel_visibility)
 369            .add_request_handler(rename_channel)
 370            .add_request_handler(join_channel_buffer)
 371            .add_request_handler(leave_channel_buffer)
 372            .add_message_handler(update_channel_buffer)
 373            .add_request_handler(rejoin_channel_buffers)
 374            .add_request_handler(get_channel_members)
 375            .add_request_handler(respond_to_channel_invite)
 376            .add_request_handler(join_channel)
 377            .add_request_handler(join_channel_chat)
 378            .add_message_handler(leave_channel_chat)
 379            .add_request_handler(send_channel_message)
 380            .add_request_handler(remove_channel_message)
 381            .add_request_handler(update_channel_message)
 382            .add_request_handler(get_channel_messages)
 383            .add_request_handler(get_channel_messages_by_id)
 384            .add_request_handler(get_notifications)
 385            .add_request_handler(mark_notification_as_read)
 386            .add_request_handler(move_channel)
 387            .add_request_handler(follow)
 388            .add_message_handler(unfollow)
 389            .add_message_handler(update_followers)
 390            .add_request_handler(get_private_user_info)
 391            .add_request_handler(get_llm_api_token)
 392            .add_request_handler(accept_terms_of_service)
 393            .add_message_handler(acknowledge_channel_message)
 394            .add_message_handler(acknowledge_buffer_version)
 395            .add_request_handler(get_supermaven_api_key)
 396            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 398            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 402            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 405            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 406            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 407            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 408            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 409            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 411            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 412            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 415            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 416            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 417            .add_message_handler(update_context);
 418
 419        Arc::new(server)
 420    }
 421
 422    pub async fn start(&self) -> Result<()> {
 423        let server_id = *self.id.lock();
 424        let app_state = self.app_state.clone();
 425        let peer = self.peer.clone();
 426        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 427        let pool = self.connection_pool.clone();
 428        let livekit_client = self.app_state.livekit_client.clone();
 429
 430        let span = info_span!("start server");
 431        self.app_state.executor.spawn_detached(
 432            async move {
 433                tracing::info!("waiting for cleanup timeout");
 434                timeout.await;
 435                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 436
 437                app_state
 438                    .db
 439                    .delete_stale_channel_chat_participants(
 440                        &app_state.config.zed_environment,
 441                        server_id,
 442                    )
 443                    .await
 444                    .trace_err();
 445
 446                if let Some((room_ids, channel_ids)) = app_state
 447                    .db
 448                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 449                    .await
 450                    .trace_err()
 451                {
 452                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 453                    tracing::info!(
 454                        stale_channel_buffer_count = channel_ids.len(),
 455                        "retrieved stale channel buffers"
 456                    );
 457
 458                    for channel_id in channel_ids {
 459                        if let Some(refreshed_channel_buffer) = app_state
 460                            .db
 461                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 462                            .await
 463                            .trace_err()
 464                        {
 465                            for connection_id in refreshed_channel_buffer.connection_ids {
 466                                peer.send(
 467                                    connection_id,
 468                                    proto::UpdateChannelBufferCollaborators {
 469                                        channel_id: channel_id.to_proto(),
 470                                        collaborators: refreshed_channel_buffer
 471                                            .collaborators
 472                                            .clone(),
 473                                    },
 474                                )
 475                                .trace_err();
 476                            }
 477                        }
 478                    }
 479
 480                    for room_id in room_ids {
 481                        let mut contacts_to_update = HashSet::default();
 482                        let mut canceled_calls_to_user_ids = Vec::new();
 483                        let mut livekit_room = String::new();
 484                        let mut delete_livekit_room = false;
 485
 486                        if let Some(mut refreshed_room) = app_state
 487                            .db
 488                            .clear_stale_room_participants(room_id, server_id)
 489                            .await
 490                            .trace_err()
 491                        {
 492                            tracing::info!(
 493                                room_id = room_id.0,
 494                                new_participant_count = refreshed_room.room.participants.len(),
 495                                "refreshed room"
 496                            );
 497                            room_updated(&refreshed_room.room, &peer);
 498                            if let Some(channel) = refreshed_room.channel.as_ref() {
 499                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 500                            }
 501                            contacts_to_update
 502                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 503                            contacts_to_update
 504                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 505                            canceled_calls_to_user_ids =
 506                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 507                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 508                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 509                        }
 510
 511                        {
 512                            let pool = pool.lock();
 513                            for canceled_user_id in canceled_calls_to_user_ids {
 514                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 515                                    peer.send(
 516                                        connection_id,
 517                                        proto::CallCanceled {
 518                                            room_id: room_id.to_proto(),
 519                                        },
 520                                    )
 521                                    .trace_err();
 522                                }
 523                            }
 524                        }
 525
 526                        for user_id in contacts_to_update {
 527                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 528                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 529                            if let Some((busy, contacts)) = busy.zip(contacts) {
 530                                let pool = pool.lock();
 531                                let updated_contact = contact_for_user(user_id, busy, &pool);
 532                                for contact in contacts {
 533                                    if let db::Contact::Accepted {
 534                                        user_id: contact_user_id,
 535                                        ..
 536                                    } = contact
 537                                    {
 538                                        for contact_conn_id in
 539                                            pool.user_connection_ids(contact_user_id)
 540                                        {
 541                                            peer.send(
 542                                                contact_conn_id,
 543                                                proto::UpdateContacts {
 544                                                    contacts: vec![updated_contact.clone()],
 545                                                    remove_contacts: Default::default(),
 546                                                    incoming_requests: Default::default(),
 547                                                    remove_incoming_requests: Default::default(),
 548                                                    outgoing_requests: Default::default(),
 549                                                    remove_outgoing_requests: Default::default(),
 550                                                },
 551                                            )
 552                                            .trace_err();
 553                                        }
 554                                    }
 555                                }
 556                            }
 557                        }
 558
 559                        if let Some(live_kit) = livekit_client.as_ref() {
 560                            if delete_livekit_room {
 561                                live_kit.delete_room(livekit_room).await.trace_err();
 562                            }
 563                        }
 564                    }
 565                }
 566
 567                app_state
 568                    .db
 569                    .delete_stale_channel_chat_participants(
 570                        &app_state.config.zed_environment,
 571                        server_id,
 572                    )
 573                    .await
 574                    .trace_err();
 575
 576                app_state
 577                    .db
 578                    .clear_old_worktree_entries(server_id)
 579                    .await
 580                    .trace_err();
 581
 582                app_state
 583                    .db
 584                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 585                    .await
 586                    .trace_err();
 587            }
 588            .instrument(span),
 589        );
 590        Ok(())
 591    }
 592
 593    pub fn teardown(&self) {
 594        self.peer.teardown();
 595        self.connection_pool.lock().reset();
 596        let _ = self.teardown.send(true);
 597    }
 598
 599    #[cfg(test)]
 600    pub fn reset(&self, id: ServerId) {
 601        self.teardown();
 602        *self.id.lock() = id;
 603        self.peer.reset(id.0 as u32);
 604        let _ = self.teardown.send(false);
 605    }
 606
 607    #[cfg(test)]
 608    pub fn id(&self) -> ServerId {
 609        *self.id.lock()
 610    }
 611
 612    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 613    where
 614        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 615        Fut: 'static + Send + Future<Output = Result<()>>,
 616        M: EnvelopedMessage,
 617    {
 618        let prev_handler = self.handlers.insert(
 619            TypeId::of::<M>(),
 620            Box::new(move |envelope, session| {
 621                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 622                let received_at = envelope.received_at;
 623                tracing::info!("message received");
 624                let start_time = Instant::now();
 625                let future = (handler)(*envelope, session);
 626                async move {
 627                    let result = future.await;
 628                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 629                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 630                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 631                    let payload_type = M::NAME;
 632
 633                    match result {
 634                        Err(error) => {
 635                            tracing::error!(
 636                                ?error,
 637                                total_duration_ms,
 638                                processing_duration_ms,
 639                                queue_duration_ms,
 640                                payload_type,
 641                                "error handling message"
 642                            )
 643                        }
 644                        Ok(()) => tracing::info!(
 645                            total_duration_ms,
 646                            processing_duration_ms,
 647                            queue_duration_ms,
 648                            "finished handling message"
 649                        ),
 650                    }
 651                }
 652                .boxed()
 653            }),
 654        );
 655        if prev_handler.is_some() {
 656            panic!("registered a handler for the same message twice");
 657        }
 658        self
 659    }
 660
 661    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 662    where
 663        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 664        Fut: 'static + Send + Future<Output = Result<()>>,
 665        M: EnvelopedMessage,
 666    {
 667        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 668        self
 669    }
 670
 671    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 672    where
 673        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 674        Fut: Send + Future<Output = Result<()>>,
 675        M: RequestMessage,
 676    {
 677        let handler = Arc::new(handler);
 678        self.add_handler(move |envelope, session| {
 679            let receipt = envelope.receipt();
 680            let handler = handler.clone();
 681            async move {
 682                let peer = session.peer.clone();
 683                let responded = Arc::new(AtomicBool::default());
 684                let response = Response {
 685                    peer: peer.clone(),
 686                    responded: responded.clone(),
 687                    receipt,
 688                };
 689                match (handler)(envelope.payload, response, session).await {
 690                    Ok(()) => {
 691                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 692                            Ok(())
 693                        } else {
 694                            Err(anyhow!("handler did not send a response"))?
 695                        }
 696                    }
 697                    Err(error) => {
 698                        let proto_err = match &error {
 699                            Error::Internal(err) => err.to_proto(),
 700                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 701                        };
 702                        peer.respond_with_error(receipt, proto_err)?;
 703                        Err(error)
 704                    }
 705                }
 706            }
 707        })
 708    }
 709
 710    pub fn handle_connection(
 711        self: &Arc<Self>,
 712        connection: Connection,
 713        address: String,
 714        principal: Principal,
 715        zed_version: ZedVersion,
 716        geoip_country_code: Option<String>,
 717        system_id: Option<String>,
 718        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 719        executor: Executor,
 720    ) -> impl Future<Output = ()> + use<> {
 721        let this = self.clone();
 722        let span = info_span!("handle connection", %address,
 723            connection_id=field::Empty,
 724            user_id=field::Empty,
 725            login=field::Empty,
 726            impersonator=field::Empty,
 727            geoip_country_code=field::Empty
 728        );
 729        principal.update_span(&span);
 730        if let Some(country_code) = geoip_country_code.as_ref() {
 731            span.record("geoip_country_code", country_code);
 732        }
 733
 734        let mut teardown = self.teardown.subscribe();
 735        async move {
 736            if *teardown.borrow() {
 737                tracing::error!("server is tearing down");
 738                return
 739            }
 740            let (connection_id, handle_io, mut incoming_rx) = this
 741                .peer
 742                .add_connection(connection, {
 743                    let executor = executor.clone();
 744                    move |duration| executor.sleep(duration)
 745                });
 746            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 747
 748            tracing::info!("connection opened");
 749
 750            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 751            let http_client = match ReqwestClient::user_agent(&user_agent) {
 752                Ok(http_client) => Arc::new(http_client),
 753                Err(error) => {
 754                    tracing::error!(?error, "failed to create HTTP client");
 755                    return;
 756                }
 757            };
 758
 759            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 760                    supermaven_admin_api_key.to_string(),
 761                    http_client.clone(),
 762                )));
 763
 764            let session = Session {
 765                principal: principal.clone(),
 766                connection_id,
 767                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 768                peer: this.peer.clone(),
 769                connection_pool: this.connection_pool.clone(),
 770                app_state: this.app_state.clone(),
 771                geoip_country_code,
 772                system_id,
 773                _executor: executor.clone(),
 774                supermaven_client,
 775            };
 776
 777            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
 778                tracing::error!(?error, "failed to send initial client update");
 779                return;
 780            }
 781
 782            let handle_io = handle_io.fuse();
 783            futures::pin_mut!(handle_io);
 784
 785            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 786            // This prevents deadlocks when e.g., client A performs a request to client B and
 787            // client B performs a request to client A. If both clients stop processing further
 788            // messages until their respective request completes, they won't have a chance to
 789            // respond to the other client's request and cause a deadlock.
 790            //
 791            // This arrangement ensures we will attempt to process earlier messages first, but fall
 792            // back to processing messages arrived later in the spirit of making progress.
 793            let mut foreground_message_handlers = FuturesUnordered::new();
 794            let concurrent_handlers = Arc::new(Semaphore::new(256));
 795            loop {
 796                let next_message = async {
 797                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 798                    let message = incoming_rx.next().await;
 799                    (permit, message)
 800                }.fuse();
 801                futures::pin_mut!(next_message);
 802                futures::select_biased! {
 803                    _ = teardown.changed().fuse() => return,
 804                    result = handle_io => {
 805                        if let Err(error) = result {
 806                            tracing::error!(?error, "error handling I/O");
 807                        }
 808                        break;
 809                    }
 810                    _ = foreground_message_handlers.next() => {}
 811                    next_message = next_message => {
 812                        let (permit, message) = next_message;
 813                        if let Some(message) = message {
 814                            let type_name = message.payload_type_name();
 815                            // note: we copy all the fields from the parent span so we can query them in the logs.
 816                            // (https://github.com/tokio-rs/tracing/issues/2670).
 817                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 818                                user_id=field::Empty,
 819                                login=field::Empty,
 820                                impersonator=field::Empty,
 821                            );
 822                            principal.update_span(&span);
 823                            let span_enter = span.enter();
 824                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 825                                let is_background = message.is_background();
 826                                let handle_message = (handler)(message, session.clone());
 827                                drop(span_enter);
 828
 829                                let handle_message = async move {
 830                                    handle_message.await;
 831                                    drop(permit);
 832                                }.instrument(span);
 833                                if is_background {
 834                                    executor.spawn_detached(handle_message);
 835                                } else {
 836                                    foreground_message_handlers.push(handle_message);
 837                                }
 838                            } else {
 839                                tracing::error!("no message handler");
 840                            }
 841                        } else {
 842                            tracing::info!("connection closed");
 843                            break;
 844                        }
 845                    }
 846                }
 847            }
 848
 849            drop(foreground_message_handlers);
 850            tracing::info!("signing out");
 851            if let Err(error) = connection_lost(session, teardown, executor).await {
 852                tracing::error!(?error, "error signing out");
 853            }
 854
 855        }.instrument(span)
 856    }
 857
 858    async fn send_initial_client_update(
 859        &self,
 860        connection_id: ConnectionId,
 861        zed_version: ZedVersion,
 862        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 863        session: &Session,
 864    ) -> Result<()> {
 865        self.peer.send(
 866            connection_id,
 867            proto::Hello {
 868                peer_id: Some(connection_id.into()),
 869            },
 870        )?;
 871        tracing::info!("sent hello message");
 872        if let Some(send_connection_id) = send_connection_id.take() {
 873            let _ = send_connection_id.send(connection_id);
 874        }
 875
 876        match &session.principal {
 877            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 878                if !user.connected_once {
 879                    self.peer.send(connection_id, proto::ShowContacts {})?;
 880                    self.app_state
 881                        .db
 882                        .set_user_connected_once(user.id, true)
 883                        .await?;
 884                }
 885
 886                update_user_plan(session).await?;
 887
 888                let contacts = self.app_state.db.get_contacts(user.id).await?;
 889
 890                {
 891                    let mut pool = self.connection_pool.lock();
 892                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 893                    self.peer.send(
 894                        connection_id,
 895                        build_initial_contacts_update(contacts, &pool),
 896                    )?;
 897                }
 898
 899                if should_auto_subscribe_to_channels(zed_version) {
 900                    subscribe_user_to_channels(user.id, session).await?;
 901                }
 902
 903                if let Some(incoming_call) =
 904                    self.app_state.db.incoming_call_for_user(user.id).await?
 905                {
 906                    self.peer.send(connection_id, incoming_call)?;
 907                }
 908
 909                update_user_contacts(user.id, session).await?;
 910            }
 911        }
 912
 913        Ok(())
 914    }
 915
 916    pub async fn invite_code_redeemed(
 917        self: &Arc<Self>,
 918        inviter_id: UserId,
 919        invitee_id: UserId,
 920    ) -> Result<()> {
 921        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 922            if let Some(code) = &user.invite_code {
 923                let pool = self.connection_pool.lock();
 924                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 925                for connection_id in pool.user_connection_ids(inviter_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateContacts {
 929                            contacts: vec![invitee_contact.clone()],
 930                            ..Default::default()
 931                        },
 932                    )?;
 933                    self.peer.send(
 934                        connection_id,
 935                        proto::UpdateInviteInfo {
 936                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 937                            count: user.invite_count as u32,
 938                        },
 939                    )?;
 940                }
 941            }
 942        }
 943        Ok(())
 944    }
 945
 946    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 947        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 948            if let Some(invite_code) = &user.invite_code {
 949                let pool = self.connection_pool.lock();
 950                for connection_id in pool.user_connection_ids(user_id) {
 951                    self.peer.send(
 952                        connection_id,
 953                        proto::UpdateInviteInfo {
 954                            url: format!(
 955                                "{}{}",
 956                                self.app_state.config.invite_link_prefix, invite_code
 957                            ),
 958                            count: user.invite_count as u32,
 959                        },
 960                    )?;
 961                }
 962            }
 963        }
 964        Ok(())
 965    }
 966
 967    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 968        let user = self
 969            .app_state
 970            .db
 971            .get_user_by_id(user_id)
 972            .await?
 973            .context("user not found")?;
 974
 975        let update_user_plan = make_update_user_plan_message(
 976            &user,
 977            user.admin,
 978            &self.app_state.db,
 979            self.app_state.llm_db.clone(),
 980        )
 981        .await?;
 982
 983        let pool = self.connection_pool.lock();
 984        for connection_id in pool.user_connection_ids(user_id) {
 985            self.peer
 986                .send(connection_id, update_user_plan.clone())
 987                .trace_err();
 988        }
 989
 990        Ok(())
 991    }
 992
 993    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 994        let pool = self.connection_pool.lock();
 995        for connection_id in pool.user_connection_ids(user_id) {
 996            self.peer
 997                .send(connection_id, proto::RefreshLlmToken {})
 998                .trace_err();
 999        }
1000    }
1001
1002    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1003        ServerSnapshot {
1004            connection_pool: ConnectionPoolGuard {
1005                guard: self.connection_pool.lock(),
1006                _not_send: PhantomData,
1007            },
1008            peer: &self.peer,
1009        }
1010    }
1011}
1012
1013impl Deref for ConnectionPoolGuard<'_> {
1014    type Target = ConnectionPool;
1015
1016    fn deref(&self) -> &Self::Target {
1017        &self.guard
1018    }
1019}
1020
1021impl DerefMut for ConnectionPoolGuard<'_> {
1022    fn deref_mut(&mut self) -> &mut Self::Target {
1023        &mut self.guard
1024    }
1025}
1026
1027impl Drop for ConnectionPoolGuard<'_> {
1028    fn drop(&mut self) {
1029        #[cfg(test)]
1030        self.check_invariants();
1031    }
1032}
1033
1034fn broadcast<F>(
1035    sender_id: Option<ConnectionId>,
1036    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1037    mut f: F,
1038) where
1039    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1040{
1041    for receiver_id in receiver_ids {
1042        if Some(receiver_id) != sender_id {
1043            if let Err(error) = f(receiver_id) {
1044                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1045            }
1046        }
1047    }
1048}
1049
1050pub struct ProtocolVersion(u32);
1051
1052impl Header for ProtocolVersion {
1053    fn name() -> &'static HeaderName {
1054        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1055        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1056    }
1057
1058    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1059    where
1060        Self: Sized,
1061        I: Iterator<Item = &'i axum::http::HeaderValue>,
1062    {
1063        let version = values
1064            .next()
1065            .ok_or_else(axum::headers::Error::invalid)?
1066            .to_str()
1067            .map_err(|_| axum::headers::Error::invalid())?
1068            .parse()
1069            .map_err(|_| axum::headers::Error::invalid())?;
1070        Ok(Self(version))
1071    }
1072
1073    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1074        values.extend([self.0.to_string().parse().unwrap()]);
1075    }
1076}
1077
1078pub struct AppVersionHeader(SemanticVersion);
1079impl Header for AppVersionHeader {
1080    fn name() -> &'static HeaderName {
1081        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1082        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1083    }
1084
1085    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1086    where
1087        Self: Sized,
1088        I: Iterator<Item = &'i axum::http::HeaderValue>,
1089    {
1090        let version = values
1091            .next()
1092            .ok_or_else(axum::headers::Error::invalid)?
1093            .to_str()
1094            .map_err(|_| axum::headers::Error::invalid())?
1095            .parse()
1096            .map_err(|_| axum::headers::Error::invalid())?;
1097        Ok(Self(version))
1098    }
1099
1100    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1101        values.extend([self.0.to_string().parse().unwrap()]);
1102    }
1103}
1104
1105pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1106    Router::new()
1107        .route("/rpc", get(handle_websocket_request))
1108        .layer(
1109            ServiceBuilder::new()
1110                .layer(Extension(server.app_state.clone()))
1111                .layer(middleware::from_fn(auth::validate_header)),
1112        )
1113        .route("/metrics", get(handle_metrics))
1114        .layer(Extension(server))
1115}
1116
1117pub async fn handle_websocket_request(
1118    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1119    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1120    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1121    Extension(server): Extension<Arc<Server>>,
1122    Extension(principal): Extension<Principal>,
1123    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1124    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1125    ws: WebSocketUpgrade,
1126) -> axum::response::Response {
1127    if protocol_version != rpc::PROTOCOL_VERSION {
1128        return (
1129            StatusCode::UPGRADE_REQUIRED,
1130            "client must be upgraded".to_string(),
1131        )
1132            .into_response();
1133    }
1134
1135    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1136        return (
1137            StatusCode::UPGRADE_REQUIRED,
1138            "no version header found".to_string(),
1139        )
1140            .into_response();
1141    };
1142
1143    if !version.can_collaborate() {
1144        return (
1145            StatusCode::UPGRADE_REQUIRED,
1146            "client must be upgraded".to_string(),
1147        )
1148            .into_response();
1149    }
1150
1151    let socket_address = socket_address.to_string();
1152    ws.on_upgrade(move |socket| {
1153        let socket = socket
1154            .map_ok(to_tungstenite_message)
1155            .err_into()
1156            .with(|message| async move { to_axum_message(message) });
1157        let connection = Connection::new(Box::pin(socket));
1158        async move {
1159            server
1160                .handle_connection(
1161                    connection,
1162                    socket_address,
1163                    principal,
1164                    version,
1165                    country_code_header.map(|header| header.to_string()),
1166                    system_id_header.map(|header| header.to_string()),
1167                    None,
1168                    Executor::Production,
1169                )
1170                .await;
1171        }
1172    })
1173}
1174
1175pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1176    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1177    let connections_metric = CONNECTIONS_METRIC
1178        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1179
1180    let connections = server
1181        .connection_pool
1182        .lock()
1183        .connections()
1184        .filter(|connection| !connection.admin)
1185        .count();
1186    connections_metric.set(connections as _);
1187
1188    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1189    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1190        register_int_gauge!(
1191            "shared_projects",
1192            "number of open projects with one or more guests"
1193        )
1194        .unwrap()
1195    });
1196
1197    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1198    shared_projects_metric.set(shared_projects as _);
1199
1200    let encoder = prometheus::TextEncoder::new();
1201    let metric_families = prometheus::gather();
1202    let encoded_metrics = encoder
1203        .encode_to_string(&metric_families)
1204        .map_err(|err| anyhow!("{err}"))?;
1205    Ok(encoded_metrics)
1206}
1207
1208#[instrument(err, skip(executor))]
1209async fn connection_lost(
1210    session: Session,
1211    mut teardown: watch::Receiver<bool>,
1212    executor: Executor,
1213) -> Result<()> {
1214    session.peer.disconnect(session.connection_id);
1215    session
1216        .connection_pool()
1217        .await
1218        .remove_connection(session.connection_id)?;
1219
1220    session
1221        .db()
1222        .await
1223        .connection_lost(session.connection_id)
1224        .await
1225        .trace_err();
1226
1227    futures::select_biased! {
1228        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1229
1230            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1231            leave_room_for_session(&session, session.connection_id).await.trace_err();
1232            leave_channel_buffers_for_session(&session)
1233                .await
1234                .trace_err();
1235
1236            if !session
1237                .connection_pool()
1238                .await
1239                .is_user_online(session.user_id())
1240            {
1241                let db = session.db().await;
1242                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1243                    room_updated(&room, &session.peer);
1244                }
1245            }
1246
1247            update_user_contacts(session.user_id(), &session).await?;
1248        },
1249        _ = teardown.changed().fuse() => {}
1250    }
1251
1252    Ok(())
1253}
1254
1255/// Acknowledges a ping from a client, used to keep the connection alive.
1256async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1257    response.send(proto::Ack {})?;
1258    Ok(())
1259}
1260
1261/// Creates a new room for calling (outside of channels)
1262async fn create_room(
1263    _request: proto::CreateRoom,
1264    response: Response<proto::CreateRoom>,
1265    session: Session,
1266) -> Result<()> {
1267    let livekit_room = nanoid::nanoid!(30);
1268
1269    let live_kit_connection_info = util::maybe!(async {
1270        let live_kit = session.app_state.livekit_client.as_ref();
1271        let live_kit = live_kit?;
1272        let user_id = session.user_id().to_string();
1273
1274        let token = live_kit
1275            .room_token(&livekit_room, &user_id.to_string())
1276            .trace_err()?;
1277
1278        Some(proto::LiveKitConnectionInfo {
1279            server_url: live_kit.url().into(),
1280            token,
1281            can_publish: true,
1282        })
1283    })
1284    .await;
1285
1286    let room = session
1287        .db()
1288        .await
1289        .create_room(session.user_id(), session.connection_id, &livekit_room)
1290        .await?;
1291
1292    response.send(proto::CreateRoomResponse {
1293        room: Some(room.clone()),
1294        live_kit_connection_info,
1295    })?;
1296
1297    update_user_contacts(session.user_id(), &session).await?;
1298    Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303    request: proto::JoinRoom,
1304    response: Response<proto::JoinRoom>,
1305    session: Session,
1306) -> Result<()> {
1307    let room_id = RoomId::from_proto(request.id);
1308
1309    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311    if let Some(channel_id) = channel_id {
1312        return join_channel_internal(channel_id, Box::new(response), session).await;
1313    }
1314
1315    let joined_room = {
1316        let room = session
1317            .db()
1318            .await
1319            .join_room(room_id, session.user_id(), session.connection_id)
1320            .await?;
1321        room_updated(&room.room, &session.peer);
1322        room.into_inner()
1323    };
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id())
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340
1341    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342    {
1343        live_kit
1344            .room_token(
1345                &joined_room.room.livekit_room,
1346                &session.user_id().to_string(),
1347            )
1348            .trace_err()
1349            .map(|token| proto::LiveKitConnectionInfo {
1350                server_url: live_kit.url().into(),
1351                token,
1352                can_publish: true,
1353            })
1354    } else {
1355        None
1356    };
1357
1358    response.send(proto::JoinRoomResponse {
1359        room: Some(joined_room.room),
1360        channel_id: None,
1361        live_kit_connection_info,
1362    })?;
1363
1364    update_user_contacts(session.user_id(), &session).await?;
1365    Ok(())
1366}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The caller's membership is restored through the database's `rejoin_room`, and the
/// caller receives the room together with its re-shared and re-joined projects. Then:
/// - `room_updated` is called;
/// - for each re-shared project, every collaborator is sent `UpdateProjectCollaborator`
///   (old peer id -> this connection), then `UpdateProject` is forwarded from the
///   caller to the project's collaborators other than the caller;
/// - `notify_rejoined_projects` handles the re-joined projects;
/// - if the room belongs to a channel, `channel_updated` runs;
/// - finally the caller's contacts are updated.
///
/// # Errors
/// Database failures, failing to send the response, and errors from
/// `notify_rejoined_projects` or `update_user_contacts` are returned; per-collaborator
/// sends are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the `rejoin_room` result (presumably a DB transaction guard, given the
    // `into_inner` below; confirm) is released before the connection pool is locked.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client first.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators that the host's peer id changed to this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list on the caller's behalf.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
1460fn notify_rejoined_projects(
1461    rejoined_projects: &mut Vec<RejoinedProject>,
1462    session: &Session,
1463) -> Result<()> {
1464    for project in rejoined_projects.iter() {
1465        for collaborator in &project.collaborators {
1466            session
1467                .peer
1468                .send(
1469                    collaborator.connection_id,
1470                    proto::UpdateProjectCollaborator {
1471                        project_id: project.id.to_proto(),
1472                        old_peer_id: Some(project.old_connection_id.into()),
1473                        new_peer_id: Some(session.connection_id.into()),
1474                    },
1475                )
1476                .trace_err();
1477        }
1478    }
1479
1480    for project in rejoined_projects {
1481        for worktree in mem::take(&mut project.worktrees) {
1482            // Stream this worktree's entries.
1483            let message = proto::UpdateWorktree {
1484                project_id: project.id.to_proto(),
1485                worktree_id: worktree.id,
1486                abs_path: worktree.abs_path.clone(),
1487                root_name: worktree.root_name,
1488                updated_entries: worktree.updated_entries,
1489                removed_entries: worktree.removed_entries,
1490                scan_id: worktree.scan_id,
1491                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492                updated_repositories: worktree.updated_repositories,
1493                removed_repositories: worktree.removed_repositories,
1494            };
1495            for update in proto::split_worktree_update(message) {
1496                session.peer.send(session.connection_id, update)?;
1497            }
1498
1499            // Stream this worktree's diagnostics.
1500            for summary in worktree.diagnostic_summaries {
1501                session.peer.send(
1502                    session.connection_id,
1503                    proto::UpdateDiagnosticSummary {
1504                        project_id: project.id.to_proto(),
1505                        worktree_id: worktree.id,
1506                        summary: Some(summary),
1507                    },
1508                )?;
1509            }
1510
1511            for settings_file in worktree.settings_files {
1512                session.peer.send(
1513                    session.connection_id,
1514                    proto::UpdateWorktreeSettings {
1515                        project_id: project.id.to_proto(),
1516                        worktree_id: worktree.id,
1517                        path: settings_file.path,
1518                        content: Some(settings_file.content),
1519                        kind: Some(settings_file.kind.to_proto().into()),
1520                    },
1521                )?;
1522            }
1523        }
1524
1525        for repository in mem::take(&mut project.updated_repositories) {
1526            for update in split_repository_update(repository) {
1527                session.peer.send(session.connection_id, update)?;
1528            }
1529        }
1530
1531        for id in mem::take(&mut project.removed_repositories) {
1532            session.peer.send(
1533                session.connection_id,
1534                proto::RemoveRepository {
1535                    project_id: project.id.to_proto(),
1536                    id,
1537                },
1538            )?;
1539        }
1540    }
1541
1542    Ok(())
1543}
1544
1545/// leave room disconnects from the room.
1546async fn leave_room(
1547    _: proto::LeaveRoom,
1548    response: Response<proto::LeaveRoom>,
1549    session: Session,
1550) -> Result<()> {
1551    leave_room_for_session(&session, session.connection_id).await?;
1552    response.send(proto::Ack {})?;
1553    Ok(())
1554}
1555
1556/// Updates the permissions of someone else in the room.
1557async fn set_room_participant_role(
1558    request: proto::SetRoomParticipantRole,
1559    response: Response<proto::SetRoomParticipantRole>,
1560    session: Session,
1561) -> Result<()> {
1562    let user_id = UserId::from_proto(request.user_id);
1563    let role = ChannelRole::from(request.role());
1564
1565    let (livekit_room, can_publish) = {
1566        let room = session
1567            .db()
1568            .await
1569            .set_room_participant_role(
1570                session.user_id(),
1571                RoomId::from_proto(request.room_id),
1572                user_id,
1573                role,
1574            )
1575            .await?;
1576
1577        let livekit_room = room.livekit_room.clone();
1578        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1579        room_updated(&room, &session.peer);
1580        (livekit_room, can_publish)
1581    };
1582
1583    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1584        live_kit
1585            .update_participant(
1586                livekit_room.clone(),
1587                request.user_id.to_string(),
1588                livekit_api::proto::ParticipantPermission {
1589                    can_subscribe: true,
1590                    can_publish,
1591                    can_publish_data: can_publish,
1592                    hidden: false,
1593                    recorder: false,
1594                },
1595            )
1596            .await
1597            .trace_err();
1598    }
1599
1600    response.send(proto::Ack {})?;
1601    Ok(())
1602}
1603
/// Call someone else into the current room
///
/// Flow:
/// 1. Reject the call unless the caller has the callee as a contact.
/// 2. Record the call in the database, broadcast the updated room, and take the
///    `IncomingCall` message the database produced; then refresh the callee's contacts.
/// 3. Send that message as a request to every connection of the callee
///    concurrently. The first successful reply acknowledges the caller and ends
///    the handler. Failed replies are logged.
/// 4. If no connection answered successfully, mark the call as failed in the
///    database, broadcast the room again, refresh the callee's contacts, and
///    return a "failed to ring user" error.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result guard is dropped before contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the guard so it can be used after the guard is dropped.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The connection-pool guard is a temporary of this statement, so it is
    // released before any of the requests below are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Replies are processed in completion order; the first success wins.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call (or the callee had no connections).
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1672
1673/// Cancel an outgoing call.
1674async fn cancel_call(
1675    request: proto::CancelCall,
1676    response: Response<proto::CancelCall>,
1677    session: Session,
1678) -> Result<()> {
1679    let called_user_id = UserId::from_proto(request.called_user_id);
1680    let room_id = RoomId::from_proto(request.room_id);
1681    {
1682        let room = session
1683            .db()
1684            .await
1685            .cancel_call(room_id, session.connection_id, called_user_id)
1686            .await?;
1687        room_updated(&room, &session.peer);
1688    }
1689
1690    for connection_id in session
1691        .connection_pool()
1692        .await
1693        .user_connection_ids(called_user_id)
1694    {
1695        session
1696            .peer
1697            .send(
1698                connection_id,
1699                proto::CallCanceled {
1700                    room_id: room_id.to_proto(),
1701                },
1702            )
1703            .trace_err();
1704    }
1705    response.send(proto::Ack {})?;
1706
1707    update_user_contacts(called_user_id, &session).await?;
1708    Ok(())
1709}
1710
1711/// Decline an incoming call.
1712async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1713    let room_id = RoomId::from_proto(message.room_id);
1714    {
1715        let room = session
1716            .db()
1717            .await
1718            .decline_call(Some(room_id), session.user_id())
1719            .await?
1720            .context("declining call")?;
1721        room_updated(&room, &session.peer);
1722    }
1723
1724    for connection_id in session
1725        .connection_pool()
1726        .await
1727        .user_connection_ids(session.user_id())
1728    {
1729        session
1730            .peer
1731            .send(
1732                connection_id,
1733                proto::CallCanceled {
1734                    room_id: room_id.to_proto(),
1735                },
1736            )
1737            .trace_err();
1738    }
1739    update_user_contacts(session.user_id(), &session).await?;
1740    Ok(())
1741}
1742
1743/// Updates other participants in the room with your current location.
1744async fn update_participant_location(
1745    request: proto::UpdateParticipantLocation,
1746    response: Response<proto::UpdateParticipantLocation>,
1747    session: Session,
1748) -> Result<()> {
1749    let room_id = RoomId::from_proto(request.room_id);
1750    let location = request.location.context("invalid location")?;
1751
1752    let db = session.db().await;
1753    let room = db
1754        .update_room_participant_location(room_id, session.connection_id, location)
1755        .await?;
1756
1757    room_updated(&room, &session.peer);
1758    response.send(proto::Ack {})?;
1759    Ok(())
1760}
1761
1762/// Share a project into the room.
1763async fn share_project(
1764    request: proto::ShareProject,
1765    response: Response<proto::ShareProject>,
1766    session: Session,
1767) -> Result<()> {
1768    let (project_id, room) = &*session
1769        .db()
1770        .await
1771        .share_project(
1772            RoomId::from_proto(request.room_id),
1773            session.connection_id,
1774            &request.worktrees,
1775            request.is_ssh_project,
1776        )
1777        .await?;
1778    response.send(proto::ShareProjectResponse {
1779        project_id: project_id.to_proto(),
1780    })?;
1781    room_updated(room, &session.peer);
1782
1783    Ok(())
1784}
1785
1786/// Unshare a project from the room.
1787async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1788    let project_id = ProjectId::from_proto(message.project_id);
1789    unshare_project_internal(project_id, session.connection_id, &session).await
1790}
1791
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The database returns:
/// - a flag saying whether the project row should be deleted;
/// - the affected room, if any;
/// - the guest connections.
///
/// Guests (excluding `connection_id`) receive `UnshareProject` (best-effort via
/// `broadcast`), and the room is re-broadcast if present. If the flag is set,
/// the project is deleted in a separate database call after the first result
/// guard has been dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before `delete_project` is awaited below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so it survives the guard being dropped.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1829
/// Join someone else's shared project.
///
/// The database admits this connection to `project_id` and yields the project
/// state plus a replica id. The rest of the join (notifying collaborators and
/// streaming the project state) is done synchronously by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming. The borrowed project data
    // lives in the guard returned by `join_project`, not in `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1848
/// Abstraction over "something that can deliver a `JoinProjectResponse`".
///
/// It lets `join_project_internal` be used with responders other than
/// `Response<proto::JoinProject>`. Presumably other join paths implement it
/// (TODO confirm).
trait JoinProjectInternalResponse {
    /// Delivers the join response, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Regular `JoinProject` RPC responder: delegates to `Response`'s own `send`.
/// The fully qualified path is used so the call does not resolve back to this
/// trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1857
1858fn join_project_internal(
1859    response: impl JoinProjectInternalResponse,
1860    session: Session,
1861    project: &mut Project,
1862    replica_id: &ReplicaId,
1863) -> Result<()> {
1864    let collaborators = project
1865        .collaborators
1866        .iter()
1867        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1868        .map(|collaborator| collaborator.to_proto())
1869        .collect::<Vec<_>>();
1870    let project_id = project.id;
1871    let guest_user_id = session.user_id();
1872
1873    let worktrees = project
1874        .worktrees
1875        .iter()
1876        .map(|(id, worktree)| proto::WorktreeMetadata {
1877            id: *id,
1878            root_name: worktree.root_name.clone(),
1879            visible: worktree.visible,
1880            abs_path: worktree.abs_path.clone(),
1881        })
1882        .collect::<Vec<_>>();
1883
1884    let add_project_collaborator = proto::AddProjectCollaborator {
1885        project_id: project_id.to_proto(),
1886        collaborator: Some(proto::Collaborator {
1887            peer_id: Some(session.connection_id.into()),
1888            replica_id: replica_id.0 as u32,
1889            user_id: guest_user_id.to_proto(),
1890            is_host: false,
1891        }),
1892    };
1893
1894    for collaborator in &collaborators {
1895        session
1896            .peer
1897            .send(
1898                collaborator.peer_id.unwrap().into(),
1899                add_project_collaborator.clone(),
1900            )
1901            .trace_err();
1902    }
1903
1904    // First, we send the metadata associated with each worktree.
1905    response.send(proto::JoinProjectResponse {
1906        project_id: project.id.0 as u64,
1907        worktrees: worktrees.clone(),
1908        replica_id: replica_id.0 as u32,
1909        collaborators: collaborators.clone(),
1910        language_servers: project.language_servers.clone(),
1911        role: project.role.into(),
1912    })?;
1913
1914    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1915        // Stream this worktree's entries.
1916        let message = proto::UpdateWorktree {
1917            project_id: project_id.to_proto(),
1918            worktree_id,
1919            abs_path: worktree.abs_path.clone(),
1920            root_name: worktree.root_name,
1921            updated_entries: worktree.entries,
1922            removed_entries: Default::default(),
1923            scan_id: worktree.scan_id,
1924            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1925            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1926            removed_repositories: Default::default(),
1927        };
1928        for update in proto::split_worktree_update(message) {
1929            session.peer.send(session.connection_id, update.clone())?;
1930        }
1931
1932        // Stream this worktree's diagnostics.
1933        for summary in worktree.diagnostic_summaries {
1934            session.peer.send(
1935                session.connection_id,
1936                proto::UpdateDiagnosticSummary {
1937                    project_id: project_id.to_proto(),
1938                    worktree_id: worktree.id,
1939                    summary: Some(summary),
1940                },
1941            )?;
1942        }
1943
1944        for settings_file in worktree.settings_files {
1945            session.peer.send(
1946                session.connection_id,
1947                proto::UpdateWorktreeSettings {
1948                    project_id: project_id.to_proto(),
1949                    worktree_id: worktree.id,
1950                    path: settings_file.path,
1951                    content: Some(settings_file.content),
1952                    kind: Some(settings_file.kind.to_proto() as i32),
1953                },
1954            )?;
1955        }
1956    }
1957
1958    for repository in mem::take(&mut project.repositories) {
1959        for update in split_repository_update(repository) {
1960            session.peer.send(session.connection_id, update)?;
1961        }
1962    }
1963
1964    for language_server in &project.language_servers {
1965        session.peer.send(
1966            session.connection_id,
1967            proto::UpdateLanguageServer {
1968                project_id: project_id.to_proto(),
1969                language_server_id: language_server.id,
1970                variant: Some(
1971                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1972                        proto::LspDiskBasedDiagnosticsUpdated {},
1973                    ),
1974                ),
1975            },
1976        )?;
1977    }
1978
1979    Ok(())
1980}
1981
1982/// Leave someone elses shared project.
1983async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1984    let sender_id = session.connection_id;
1985    let project_id = ProjectId::from_proto(request.project_id);
1986    let db = session.db().await;
1987
1988    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1989    tracing::info!(
1990        %project_id,
1991        "leave project"
1992    );
1993
1994    project_left(project, &session);
1995    if let Some(room) = room {
1996        room_updated(room, &session.peer);
1997    }
1998
1999    Ok(())
2000}
2001
2002/// Updates other participants with changes to the project
2003async fn update_project(
2004    request: proto::UpdateProject,
2005    response: Response<proto::UpdateProject>,
2006    session: Session,
2007) -> Result<()> {
2008    let project_id = ProjectId::from_proto(request.project_id);
2009    let (room, guest_connection_ids) = &*session
2010        .db()
2011        .await
2012        .update_project(project_id, session.connection_id, &request.worktrees)
2013        .await?;
2014    broadcast(
2015        Some(session.connection_id),
2016        guest_connection_ids.iter().copied(),
2017        |connection_id| {
2018            session
2019                .peer
2020                .forward_send(session.connection_id, connection_id, request.clone())
2021        },
2022    );
2023    if let Some(room) = room {
2024        room_updated(room, &session.peer);
2025    }
2026    response.send(proto::Ack {})?;
2027
2028    Ok(())
2029}
2030
2031/// Updates other participants with changes to the worktree
2032async fn update_worktree(
2033    request: proto::UpdateWorktree,
2034    response: Response<proto::UpdateWorktree>,
2035    session: Session,
2036) -> Result<()> {
2037    let guest_connection_ids = session
2038        .db()
2039        .await
2040        .update_worktree(&request, session.connection_id)
2041        .await?;
2042
2043    broadcast(
2044        Some(session.connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |connection_id| {
2047            session
2048                .peer
2049                .forward_send(session.connection_id, connection_id, request.clone())
2050        },
2051    );
2052    response.send(proto::Ack {})?;
2053    Ok(())
2054}
2055
2056async fn update_repository(
2057    request: proto::UpdateRepository,
2058    response: Response<proto::UpdateRepository>,
2059    session: Session,
2060) -> Result<()> {
2061    let guest_connection_ids = session
2062        .db()
2063        .await
2064        .update_repository(&request, session.connection_id)
2065        .await?;
2066
2067    broadcast(
2068        Some(session.connection_id),
2069        guest_connection_ids.iter().copied(),
2070        |connection_id| {
2071            session
2072                .peer
2073                .forward_send(session.connection_id, connection_id, request.clone())
2074        },
2075    );
2076    response.send(proto::Ack {})?;
2077    Ok(())
2078}
2079
2080async fn remove_repository(
2081    request: proto::RemoveRepository,
2082    response: Response<proto::RemoveRepository>,
2083    session: Session,
2084) -> Result<()> {
2085    let guest_connection_ids = session
2086        .db()
2087        .await
2088        .remove_repository(&request, session.connection_id)
2089        .await?;
2090
2091    broadcast(
2092        Some(session.connection_id),
2093        guest_connection_ids.iter().copied(),
2094        |connection_id| {
2095            session
2096                .peer
2097                .forward_send(session.connection_id, connection_id, request.clone())
2098        },
2099    );
2100    response.send(proto::Ack {})?;
2101    Ok(())
2102}
2103
2104/// Updates other participants with changes to the diagnostics
2105async fn update_diagnostic_summary(
2106    message: proto::UpdateDiagnosticSummary,
2107    session: Session,
2108) -> Result<()> {
2109    let guest_connection_ids = session
2110        .db()
2111        .await
2112        .update_diagnostic_summary(&message, session.connection_id)
2113        .await?;
2114
2115    broadcast(
2116        Some(session.connection_id),
2117        guest_connection_ids.iter().copied(),
2118        |connection_id| {
2119            session
2120                .peer
2121                .forward_send(session.connection_id, connection_id, message.clone())
2122        },
2123    );
2124
2125    Ok(())
2126}
2127
2128/// Updates other participants with changes to the worktree settings
2129async fn update_worktree_settings(
2130    message: proto::UpdateWorktreeSettings,
2131    session: Session,
2132) -> Result<()> {
2133    let guest_connection_ids = session
2134        .db()
2135        .await
2136        .update_worktree_settings(&message, session.connection_id)
2137        .await?;
2138
2139    broadcast(
2140        Some(session.connection_id),
2141        guest_connection_ids.iter().copied(),
2142        |connection_id| {
2143            session
2144                .peer
2145                .forward_send(session.connection_id, connection_id, message.clone())
2146        },
2147    );
2148
2149    Ok(())
2150}
2151
2152/// Notify other participants that a language server has started.
2153async fn start_language_server(
2154    request: proto::StartLanguageServer,
2155    session: Session,
2156) -> Result<()> {
2157    let guest_connection_ids = session
2158        .db()
2159        .await
2160        .start_language_server(&request, session.connection_id)
2161        .await?;
2162
2163    broadcast(
2164        Some(session.connection_id),
2165        guest_connection_ids.iter().copied(),
2166        |connection_id| {
2167            session
2168                .peer
2169                .forward_send(session.connection_id, connection_id, request.clone())
2170        },
2171    );
2172    Ok(())
2173}
2174
2175/// Notify other participants that a language server has changed.
2176async fn update_language_server(
2177    request: proto::UpdateLanguageServer,
2178    session: Session,
2179) -> Result<()> {
2180    let project_id = ProjectId::from_proto(request.project_id);
2181    let project_connection_ids = session
2182        .db()
2183        .await
2184        .project_connection_ids(project_id, session.connection_id, true)
2185        .await?;
2186    broadcast(
2187        Some(session.connection_id),
2188        project_connection_ids.iter().copied(),
2189        |connection_id| {
2190            session
2191                .peer
2192                .forward_send(session.connection_id, connection_id, request.clone())
2193        },
2194    );
2195    Ok(())
2196}
2197
2198/// forward a project request to the host. These requests should be read only
2199/// as guests are allowed to send them.
2200async fn forward_read_only_project_request<T>(
2201    request: T,
2202    response: Response<T>,
2203    session: Session,
2204) -> Result<()>
2205where
2206    T: EntityMessage + RequestMessage,
2207{
2208    let project_id = ProjectId::from_proto(request.remote_entity_id());
2209    let host_connection_id = session
2210        .db()
2211        .await
2212        .host_for_read_only_project_request(project_id, session.connection_id)
2213        .await?;
2214    let payload = session
2215        .peer
2216        .forward_request(session.connection_id, host_connection_id, request)
2217        .await?;
2218    response.send(payload)?;
2219    Ok(())
2220}
2221
2222async fn forward_find_search_candidates_request(
2223    request: proto::FindSearchCandidates,
2224    response: Response<proto::FindSearchCandidates>,
2225    session: Session,
2226) -> Result<()> {
2227    let project_id = ProjectId::from_proto(request.remote_entity_id());
2228    let host_connection_id = session
2229        .db()
2230        .await
2231        .host_for_read_only_project_request(project_id, session.connection_id)
2232        .await?;
2233    let payload = session
2234        .peer
2235        .forward_request(session.connection_id, host_connection_id, request)
2236        .await?;
2237    response.send(payload)?;
2238    Ok(())
2239}
2240
2241/// forward a project request to the host. These requests are disallowed
2242/// for guests.
2243async fn forward_mutating_project_request<T>(
2244    request: T,
2245    response: Response<T>,
2246    session: Session,
2247) -> Result<()>
2248where
2249    T: EntityMessage + RequestMessage,
2250{
2251    let project_id = ProjectId::from_proto(request.remote_entity_id());
2252
2253    let host_connection_id = session
2254        .db()
2255        .await
2256        .host_for_mutating_project_request(project_id, session.connection_id)
2257        .await?;
2258    let payload = session
2259        .peer
2260        .forward_request(session.connection_id, host_connection_id, request)
2261        .await?;
2262    response.send(payload)?;
2263    Ok(())
2264}
2265
2266/// Notify other participants that a new buffer has been created
2267async fn create_buffer_for_peer(
2268    request: proto::CreateBufferForPeer,
2269    session: Session,
2270) -> Result<()> {
2271    session
2272        .db()
2273        .await
2274        .check_user_is_project_host(
2275            ProjectId::from_proto(request.project_id),
2276            session.connection_id,
2277        )
2278        .await?;
2279    let peer_id = request.peer_id.context("invalid peer id")?;
2280    session
2281        .peer
2282        .forward_send(session.connection_id, peer_id.into(), request)?;
2283    Ok(())
2284}
2285
2286/// Notify other participants that a buffer has been updated. This is
2287/// allowed for guests as long as the update is limited to selections.
2288async fn update_buffer(
2289    request: proto::UpdateBuffer,
2290    response: Response<proto::UpdateBuffer>,
2291    session: Session,
2292) -> Result<()> {
2293    let project_id = ProjectId::from_proto(request.project_id);
2294    let mut capability = Capability::ReadOnly;
2295
2296    for op in request.operations.iter() {
2297        match op.variant {
2298            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2299            Some(_) => capability = Capability::ReadWrite,
2300        }
2301    }
2302
2303    let host = {
2304        let guard = session
2305            .db()
2306            .await
2307            .connections_for_buffer_update(project_id, session.connection_id, capability)
2308            .await?;
2309
2310        let (host, guests) = &*guard;
2311
2312        broadcast(
2313            Some(session.connection_id),
2314            guests.clone(),
2315            |connection_id| {
2316                session
2317                    .peer
2318                    .forward_send(session.connection_id, connection_id, request.clone())
2319            },
2320        );
2321
2322        *host
2323    };
2324
2325    if host != session.connection_id {
2326        session
2327            .peer
2328            .forward_request(session.connection_id, host, request.clone())
2329            .await?;
2330    }
2331
2332    response.send(proto::Ack {})?;
2333    Ok(())
2334}
2335
2336async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2337    let project_id = ProjectId::from_proto(message.project_id);
2338
2339    let operation = message.operation.as_ref().context("invalid operation")?;
2340    let capability = match operation.variant.as_ref() {
2341        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2342            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2343                match buffer_op.variant {
2344                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2345                        Capability::ReadOnly
2346                    }
2347                    _ => Capability::ReadWrite,
2348                }
2349            } else {
2350                Capability::ReadWrite
2351            }
2352        }
2353        Some(_) => Capability::ReadWrite,
2354        None => Capability::ReadOnly,
2355    };
2356
2357    let guard = session
2358        .db()
2359        .await
2360        .connections_for_buffer_update(project_id, session.connection_id, capability)
2361        .await?;
2362
2363    let (host, guests) = &*guard;
2364
2365    broadcast(
2366        Some(session.connection_id),
2367        guests.iter().chain([host]).copied(),
2368        |connection_id| {
2369            session
2370                .peer
2371                .forward_send(session.connection_id, connection_id, message.clone())
2372        },
2373    );
2374
2375    Ok(())
2376}
2377
2378/// Notify other participants that a project has been updated.
2379async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2380    request: T,
2381    session: Session,
2382) -> Result<()> {
2383    let project_id = ProjectId::from_proto(request.remote_entity_id());
2384    let project_connection_ids = session
2385        .db()
2386        .await
2387        .project_connection_ids(project_id, session.connection_id, false)
2388        .await?;
2389
2390    broadcast(
2391        Some(session.connection_id),
2392        project_connection_ids.iter().copied(),
2393        |connection_id| {
2394            session
2395                .peer
2396                .forward_send(session.connection_id, connection_id, request.clone())
2397        },
2398    );
2399    Ok(())
2400}
2401
2402/// Start following another user in a call.
2403async fn follow(
2404    request: proto::Follow,
2405    response: Response<proto::Follow>,
2406    session: Session,
2407) -> Result<()> {
2408    let room_id = RoomId::from_proto(request.room_id);
2409    let project_id = request.project_id.map(ProjectId::from_proto);
2410    let leader_id = request.leader_id.context("invalid leader id")?.into();
2411    let follower_id = session.connection_id;
2412
2413    session
2414        .db()
2415        .await
2416        .check_room_participants(room_id, leader_id, session.connection_id)
2417        .await?;
2418
2419    let response_payload = session
2420        .peer
2421        .forward_request(session.connection_id, leader_id, request)
2422        .await?;
2423    response.send(response_payload)?;
2424
2425    if let Some(project_id) = project_id {
2426        let room = session
2427            .db()
2428            .await
2429            .follow(room_id, project_id, leader_id, follower_id)
2430            .await?;
2431        room_updated(&room, &session.peer);
2432    }
2433
2434    Ok(())
2435}
2436
2437/// Stop following another user in a call.
2438async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2439    let room_id = RoomId::from_proto(request.room_id);
2440    let project_id = request.project_id.map(ProjectId::from_proto);
2441    let leader_id = request.leader_id.context("invalid leader id")?.into();
2442    let follower_id = session.connection_id;
2443
2444    session
2445        .db()
2446        .await
2447        .check_room_participants(room_id, leader_id, session.connection_id)
2448        .await?;
2449
2450    session
2451        .peer
2452        .forward_send(session.connection_id, leader_id, request)?;
2453
2454    if let Some(project_id) = project_id {
2455        let room = session
2456            .db()
2457            .await
2458            .unfollow(room_id, project_id, leader_id, follower_id)
2459            .await?;
2460        room_updated(&room, &session.peer);
2461    }
2462
2463    Ok(())
2464}
2465
2466/// Notify everyone following you of your current location.
2467async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2468    let room_id = RoomId::from_proto(request.room_id);
2469    let database = session.db.lock().await;
2470
2471    let connection_ids = if let Some(project_id) = request.project_id {
2472        let project_id = ProjectId::from_proto(project_id);
2473        database
2474            .project_connection_ids(project_id, session.connection_id, true)
2475            .await?
2476    } else {
2477        database
2478            .room_connection_ids(room_id, session.connection_id)
2479            .await?
2480    };
2481
2482    // For now, don't send view update messages back to that view's current leader.
2483    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2484        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2485        _ => None,
2486    });
2487
2488    for connection_id in connection_ids.iter().cloned() {
2489        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2490            session
2491                .peer
2492                .forward_send(session.connection_id, connection_id, request.clone())?;
2493        }
2494    }
2495    Ok(())
2496}
2497
2498/// Get public data about users.
2499async fn get_users(
2500    request: proto::GetUsers,
2501    response: Response<proto::GetUsers>,
2502    session: Session,
2503) -> Result<()> {
2504    let user_ids = request
2505        .user_ids
2506        .into_iter()
2507        .map(UserId::from_proto)
2508        .collect();
2509    let users = session
2510        .db()
2511        .await
2512        .get_users_by_ids(user_ids)
2513        .await?
2514        .into_iter()
2515        .map(|user| proto::User {
2516            id: user.id.to_proto(),
2517            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2518            github_login: user.github_login,
2519            email: user.email_address,
2520            name: user.name,
2521        })
2522        .collect();
2523    response.send(proto::UsersResponse { users })?;
2524    Ok(())
2525}
2526
2527/// Search for users (to invite) buy Github login
2528async fn fuzzy_search_users(
2529    request: proto::FuzzySearchUsers,
2530    response: Response<proto::FuzzySearchUsers>,
2531    session: Session,
2532) -> Result<()> {
2533    let query = request.query;
2534    let users = match query.len() {
2535        0 => vec![],
2536        1 | 2 => session
2537            .db()
2538            .await
2539            .get_user_by_github_login(&query)
2540            .await?
2541            .into_iter()
2542            .collect(),
2543        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2544    };
2545    let users = users
2546        .into_iter()
2547        .filter(|user| user.id != session.user_id())
2548        .map(|user| proto::User {
2549            id: user.id.to_proto(),
2550            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2551            github_login: user.github_login,
2552            name: user.name,
2553            email: user.email_address,
2554        })
2555        .collect();
2556    response.send(proto::UsersResponse { users })?;
2557    Ok(())
2558}
2559
2560/// Send a contact request to another user.
2561async fn request_contact(
2562    request: proto::RequestContact,
2563    response: Response<proto::RequestContact>,
2564    session: Session,
2565) -> Result<()> {
2566    let requester_id = session.user_id();
2567    let responder_id = UserId::from_proto(request.responder_id);
2568    if requester_id == responder_id {
2569        return Err(anyhow!("cannot add yourself as a contact"))?;
2570    }
2571
2572    let notifications = session
2573        .db()
2574        .await
2575        .send_contact_request(requester_id, responder_id)
2576        .await?;
2577
2578    // Update outgoing contact requests of requester
2579    let mut update = proto::UpdateContacts::default();
2580    update.outgoing_requests.push(responder_id.to_proto());
2581    for connection_id in session
2582        .connection_pool()
2583        .await
2584        .user_connection_ids(requester_id)
2585    {
2586        session.peer.send(connection_id, update.clone())?;
2587    }
2588
2589    // Update incoming contact requests of responder
2590    let mut update = proto::UpdateContacts::default();
2591    update
2592        .incoming_requests
2593        .push(proto::IncomingContactRequest {
2594            requester_id: requester_id.to_proto(),
2595        });
2596    let connection_pool = session.connection_pool().await;
2597    for connection_id in connection_pool.user_connection_ids(responder_id) {
2598        session.peer.send(connection_id, update.clone())?;
2599    }
2600
2601    send_notifications(&connection_pool, &session.peer, notifications);
2602
2603    response.send(proto::Ack {})?;
2604    Ok(())
2605}
2606
2607/// Accept or decline a contact request
2608async fn respond_to_contact_request(
2609    request: proto::RespondToContactRequest,
2610    response: Response<proto::RespondToContactRequest>,
2611    session: Session,
2612) -> Result<()> {
2613    let responder_id = session.user_id();
2614    let requester_id = UserId::from_proto(request.requester_id);
2615    let db = session.db().await;
2616    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2617        db.dismiss_contact_notification(responder_id, requester_id)
2618            .await?;
2619    } else {
2620        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2621
2622        let notifications = db
2623            .respond_to_contact_request(responder_id, requester_id, accept)
2624            .await?;
2625        let requester_busy = db.is_user_busy(requester_id).await?;
2626        let responder_busy = db.is_user_busy(responder_id).await?;
2627
2628        let pool = session.connection_pool().await;
2629        // Update responder with new contact
2630        let mut update = proto::UpdateContacts::default();
2631        if accept {
2632            update
2633                .contacts
2634                .push(contact_for_user(requester_id, requester_busy, &pool));
2635        }
2636        update
2637            .remove_incoming_requests
2638            .push(requester_id.to_proto());
2639        for connection_id in pool.user_connection_ids(responder_id) {
2640            session.peer.send(connection_id, update.clone())?;
2641        }
2642
2643        // Update requester with new contact
2644        let mut update = proto::UpdateContacts::default();
2645        if accept {
2646            update
2647                .contacts
2648                .push(contact_for_user(responder_id, responder_busy, &pool));
2649        }
2650        update
2651            .remove_outgoing_requests
2652            .push(responder_id.to_proto());
2653
2654        for connection_id in pool.user_connection_ids(requester_id) {
2655            session.peer.send(connection_id, update.clone())?;
2656        }
2657
2658        send_notifications(&pool, &session.peer, notifications);
2659    }
2660
2661    response.send(proto::Ack {})?;
2662    Ok(())
2663}
2664
2665/// Remove a contact.
2666async fn remove_contact(
2667    request: proto::RemoveContact,
2668    response: Response<proto::RemoveContact>,
2669    session: Session,
2670) -> Result<()> {
2671    let requester_id = session.user_id();
2672    let responder_id = UserId::from_proto(request.user_id);
2673    let db = session.db().await;
2674    let (contact_accepted, deleted_notification_id) =
2675        db.remove_contact(requester_id, responder_id).await?;
2676
2677    let pool = session.connection_pool().await;
2678    // Update outgoing contact requests of requester
2679    let mut update = proto::UpdateContacts::default();
2680    if contact_accepted {
2681        update.remove_contacts.push(responder_id.to_proto());
2682    } else {
2683        update
2684            .remove_outgoing_requests
2685            .push(responder_id.to_proto());
2686    }
2687    for connection_id in pool.user_connection_ids(requester_id) {
2688        session.peer.send(connection_id, update.clone())?;
2689    }
2690
2691    // Update incoming contact requests of responder
2692    let mut update = proto::UpdateContacts::default();
2693    if contact_accepted {
2694        update.remove_contacts.push(requester_id.to_proto());
2695    } else {
2696        update
2697            .remove_incoming_requests
2698            .push(requester_id.to_proto());
2699    }
2700    for connection_id in pool.user_connection_ids(responder_id) {
2701        session.peer.send(connection_id, update.clone())?;
2702        if let Some(notification_id) = deleted_notification_id {
2703            session.peer.send(
2704                connection_id,
2705                proto::DeleteNotification {
2706                    notification_id: notification_id.to_proto(),
2707                },
2708            )?;
2709        }
2710    }
2711
2712    response.send(proto::Ack {})?;
2713    Ok(())
2714}
2715
2716fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2717    version.0.minor() < 139
2718}
2719
2720async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2721    if is_staff {
2722        return Ok(proto::Plan::ZedPro);
2723    }
2724
2725    let subscription = db.get_active_billing_subscription(user_id).await?;
2726    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2727
2728    let plan = if let Some(subscription_kind) = subscription_kind {
2729        match subscription_kind {
2730            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2731            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2732            SubscriptionKind::ZedFree => proto::Plan::Free,
2733        }
2734    } else {
2735        proto::Plan::Free
2736    };
2737
2738    Ok(plan)
2739}
2740
/// Build the `UpdateUserPlan` message describing a user's plan and billing state.
///
/// The message covers:
/// - the plan, trial start, and usage-based-billing preference;
/// - the account-age gate and overdue-invoice status;
/// - when `llm_db` is provided, the current subscription period and the usage recorded
///   for it, with per-plan limits.
///
/// # Errors
/// Propagates any database error from the lookups below.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    // Feature flags are only consulted below for the extended-trial request limit.
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Period and usage are only computed when an LLM database is available; otherwise
    // both are reported as absent.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already queried the active subscription for
        // non-staff users; this is a second query for the same data.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Only Zed Pro users are exempt from the minimum-account-age check.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Sent as Unix seconds; the stored timestamp is interpreted as UTC via `and_utc`.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report overages as enabled. Other users report their stored
        // preference, or None when they have no billing preferences.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Translate the proto plan into the LLM client's plan to obtain its limits.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    // Trial users with the extended-trial feature flag get a fixed
                    // 1,000-request limit instead of the plan default.
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below truncate/wrap silently if the source
            // values fall outside the u32 range. Confirm the ranges of the usage counters
            // and limits.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2849
2850async fn update_user_plan(session: &Session) -> Result<()> {
2851    let db = session.db().await;
2852
2853    let update_user_plan = make_update_user_plan_message(
2854        session.principal.user(),
2855        session.is_staff(),
2856        &db.0,
2857        session.app_state.llm_db.clone(),
2858    )
2859    .await?;
2860
2861    session
2862        .peer
2863        .send(session.connection_id, update_user_plan)
2864        .trace_err();
2865
2866    Ok(())
2867}
2868
2869async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2870    subscribe_user_to_channels(session.user_id(), &session).await?;
2871    Ok(())
2872}
2873
2874async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2875    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2876    let mut pool = session.connection_pool().await;
2877    for membership in &channels_for_user.channel_memberships {
2878        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2879    }
2880    session.peer.send(
2881        session.connection_id,
2882        build_update_user_channels(&channels_for_user),
2883    )?;
2884    session.peer.send(
2885        session.connection_id,
2886        build_channels_update(channels_for_user),
2887    )?;
2888    Ok(())
2889}
2890
2891/// Creates a new channel.
2892async fn create_channel(
2893    request: proto::CreateChannel,
2894    response: Response<proto::CreateChannel>,
2895    session: Session,
2896) -> Result<()> {
2897    let db = session.db().await;
2898
2899    let parent_id = request.parent_id.map(ChannelId::from_proto);
2900    let (channel, membership) = db
2901        .create_channel(&request.name, parent_id, session.user_id())
2902        .await?;
2903
2904    let root_id = channel.root_id();
2905    let channel = Channel::from_model(channel);
2906
2907    response.send(proto::CreateChannelResponse {
2908        channel: Some(channel.to_proto()),
2909        parent_id: request.parent_id,
2910    })?;
2911
2912    let mut connection_pool = session.connection_pool().await;
2913    if let Some(membership) = membership {
2914        connection_pool.subscribe_to_channel(
2915            membership.user_id,
2916            membership.channel_id,
2917            membership.role,
2918        );
2919        let update = proto::UpdateUserChannels {
2920            channel_memberships: vec![proto::ChannelMembership {
2921                channel_id: membership.channel_id.to_proto(),
2922                role: membership.role.into(),
2923            }],
2924            ..Default::default()
2925        };
2926        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2927            session.peer.send(connection_id, update.clone())?;
2928        }
2929    }
2930
2931    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2932        if !role.can_see_channel(channel.visibility) {
2933            continue;
2934        }
2935
2936        let update = proto::UpdateChannels {
2937            channels: vec![channel.to_proto()],
2938            ..Default::default()
2939        };
2940        session.peer.send(connection_id, update.clone())?;
2941    }
2942
2943    Ok(())
2944}
2945
2946/// Delete a channel
2947async fn delete_channel(
2948    request: proto::DeleteChannel,
2949    response: Response<proto::DeleteChannel>,
2950    session: Session,
2951) -> Result<()> {
2952    let db = session.db().await;
2953
2954    let channel_id = request.channel_id;
2955    let (root_channel, removed_channels) = db
2956        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2957        .await?;
2958    response.send(proto::Ack {})?;
2959
2960    // Notify members of removed channels
2961    let mut update = proto::UpdateChannels::default();
2962    update
2963        .delete_channels
2964        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2965
2966    let connection_pool = session.connection_pool().await;
2967    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2968        session.peer.send(connection_id, update.clone())?;
2969    }
2970
2971    Ok(())
2972}
2973
2974/// Invite someone to join a channel.
2975async fn invite_channel_member(
2976    request: proto::InviteChannelMember,
2977    response: Response<proto::InviteChannelMember>,
2978    session: Session,
2979) -> Result<()> {
2980    let db = session.db().await;
2981    let channel_id = ChannelId::from_proto(request.channel_id);
2982    let invitee_id = UserId::from_proto(request.user_id);
2983    let InviteMemberResult {
2984        channel,
2985        notifications,
2986    } = db
2987        .invite_channel_member(
2988            channel_id,
2989            invitee_id,
2990            session.user_id(),
2991            request.role().into(),
2992        )
2993        .await?;
2994
2995    let update = proto::UpdateChannels {
2996        channel_invitations: vec![channel.to_proto()],
2997        ..Default::default()
2998    };
2999
3000    let connection_pool = session.connection_pool().await;
3001    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3002        session.peer.send(connection_id, update.clone())?;
3003    }
3004
3005    send_notifications(&connection_pool, &session.peer, notifications);
3006
3007    response.send(proto::Ack {})?;
3008    Ok(())
3009}
3010
3011/// remove someone from a channel
3012async fn remove_channel_member(
3013    request: proto::RemoveChannelMember,
3014    response: Response<proto::RemoveChannelMember>,
3015    session: Session,
3016) -> Result<()> {
3017    let db = session.db().await;
3018    let channel_id = ChannelId::from_proto(request.channel_id);
3019    let member_id = UserId::from_proto(request.user_id);
3020
3021    let RemoveChannelMemberResult {
3022        membership_update,
3023        notification_id,
3024    } = db
3025        .remove_channel_member(channel_id, member_id, session.user_id())
3026        .await?;
3027
3028    let mut connection_pool = session.connection_pool().await;
3029    notify_membership_updated(
3030        &mut connection_pool,
3031        membership_update,
3032        member_id,
3033        &session.peer,
3034    );
3035    for connection_id in connection_pool.user_connection_ids(member_id) {
3036        if let Some(notification_id) = notification_id {
3037            session
3038                .peer
3039                .send(
3040                    connection_id,
3041                    proto::DeleteNotification {
3042                        notification_id: notification_id.to_proto(),
3043                    },
3044                )
3045                .trace_err();
3046        }
3047    }
3048
3049    response.send(proto::Ack {})?;
3050    Ok(())
3051}
3052
3053/// Toggle the channel between public and private.
3054/// Care is taken to maintain the invariant that public channels only descend from public channels,
3055/// (though members-only channels can appear at any point in the hierarchy).
3056async fn set_channel_visibility(
3057    request: proto::SetChannelVisibility,
3058    response: Response<proto::SetChannelVisibility>,
3059    session: Session,
3060) -> Result<()> {
3061    let db = session.db().await;
3062    let channel_id = ChannelId::from_proto(request.channel_id);
3063    let visibility = request.visibility().into();
3064
3065    let channel_model = db
3066        .set_channel_visibility(channel_id, visibility, session.user_id())
3067        .await?;
3068    let root_id = channel_model.root_id();
3069    let channel = Channel::from_model(channel_model);
3070
3071    let mut connection_pool = session.connection_pool().await;
3072    for (user_id, role) in connection_pool
3073        .channel_user_ids(root_id)
3074        .collect::<Vec<_>>()
3075        .into_iter()
3076    {
3077        let update = if role.can_see_channel(channel.visibility) {
3078            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3079            proto::UpdateChannels {
3080                channels: vec![channel.to_proto()],
3081                ..Default::default()
3082            }
3083        } else {
3084            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3085            proto::UpdateChannels {
3086                delete_channels: vec![channel.id.to_proto()],
3087                ..Default::default()
3088            }
3089        };
3090
3091        for connection_id in connection_pool.user_connection_ids(user_id) {
3092            session.peer.send(connection_id, update.clone())?;
3093        }
3094    }
3095
3096    response.send(proto::Ack {})?;
3097    Ok(())
3098}
3099
3100/// Alter the role for a user in the channel.
3101async fn set_channel_member_role(
3102    request: proto::SetChannelMemberRole,
3103    response: Response<proto::SetChannelMemberRole>,
3104    session: Session,
3105) -> Result<()> {
3106    let db = session.db().await;
3107    let channel_id = ChannelId::from_proto(request.channel_id);
3108    let member_id = UserId::from_proto(request.user_id);
3109    let result = db
3110        .set_channel_member_role(
3111            channel_id,
3112            session.user_id(),
3113            member_id,
3114            request.role().into(),
3115        )
3116        .await?;
3117
3118    match result {
3119        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3120            let mut connection_pool = session.connection_pool().await;
3121            notify_membership_updated(
3122                &mut connection_pool,
3123                membership_update,
3124                member_id,
3125                &session.peer,
3126            )
3127        }
3128        db::SetMemberRoleResult::InviteUpdated(channel) => {
3129            let update = proto::UpdateChannels {
3130                channel_invitations: vec![channel.to_proto()],
3131                ..Default::default()
3132            };
3133
3134            for connection_id in session
3135                .connection_pool()
3136                .await
3137                .user_connection_ids(member_id)
3138            {
3139                session.peer.send(connection_id, update.clone())?;
3140            }
3141        }
3142    }
3143
3144    response.send(proto::Ack {})?;
3145    Ok(())
3146}
3147
3148/// Change the name of a channel
3149async fn rename_channel(
3150    request: proto::RenameChannel,
3151    response: Response<proto::RenameChannel>,
3152    session: Session,
3153) -> Result<()> {
3154    let db = session.db().await;
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156    let channel_model = db
3157        .rename_channel(channel_id, session.user_id(), &request.name)
3158        .await?;
3159    let root_id = channel_model.root_id();
3160    let channel = Channel::from_model(channel_model);
3161
3162    response.send(proto::RenameChannelResponse {
3163        channel: Some(channel.to_proto()),
3164    })?;
3165
3166    let connection_pool = session.connection_pool().await;
3167    let update = proto::UpdateChannels {
3168        channels: vec![channel.to_proto()],
3169        ..Default::default()
3170    };
3171    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172        if role.can_see_channel(channel.visibility) {
3173            session.peer.send(connection_id, update.clone())?;
3174        }
3175    }
3176
3177    Ok(())
3178}
3179
3180/// Move a channel to a new parent.
3181async fn move_channel(
3182    request: proto::MoveChannel,
3183    response: Response<proto::MoveChannel>,
3184    session: Session,
3185) -> Result<()> {
3186    let channel_id = ChannelId::from_proto(request.channel_id);
3187    let to = ChannelId::from_proto(request.to);
3188
3189    let (root_id, channels) = session
3190        .db()
3191        .await
3192        .move_channel(channel_id, to, session.user_id())
3193        .await?;
3194
3195    let connection_pool = session.connection_pool().await;
3196    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3197        let channels = channels
3198            .iter()
3199            .filter_map(|channel| {
3200                if role.can_see_channel(channel.visibility) {
3201                    Some(channel.to_proto())
3202                } else {
3203                    None
3204                }
3205            })
3206            .collect::<Vec<_>>();
3207        if channels.is_empty() {
3208            continue;
3209        }
3210
3211        let update = proto::UpdateChannels {
3212            channels,
3213            ..Default::default()
3214        };
3215
3216        session.peer.send(connection_id, update.clone())?;
3217    }
3218
3219    response.send(Ack {})?;
3220    Ok(())
3221}
3222
3223/// Get the list of channel members
3224async fn get_channel_members(
3225    request: proto::GetChannelMembers,
3226    response: Response<proto::GetChannelMembers>,
3227    session: Session,
3228) -> Result<()> {
3229    let db = session.db().await;
3230    let channel_id = ChannelId::from_proto(request.channel_id);
3231    let limit = if request.limit == 0 {
3232        u16::MAX as u64
3233    } else {
3234        request.limit
3235    };
3236    let (members, users) = db
3237        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3238        .await?;
3239    response.send(proto::GetChannelMembersResponse { members, users })?;
3240    Ok(())
3241}
3242
3243/// Accept or decline a channel invitation.
3244async fn respond_to_channel_invite(
3245    request: proto::RespondToChannelInvite,
3246    response: Response<proto::RespondToChannelInvite>,
3247    session: Session,
3248) -> Result<()> {
3249    let db = session.db().await;
3250    let channel_id = ChannelId::from_proto(request.channel_id);
3251    let RespondToChannelInvite {
3252        membership_update,
3253        notifications,
3254    } = db
3255        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3256        .await?;
3257
3258    let mut connection_pool = session.connection_pool().await;
3259    if let Some(membership_update) = membership_update {
3260        notify_membership_updated(
3261            &mut connection_pool,
3262            membership_update,
3263            session.user_id(),
3264            &session.peer,
3265        );
3266    } else {
3267        let update = proto::UpdateChannels {
3268            remove_channel_invitations: vec![channel_id.to_proto()],
3269            ..Default::default()
3270        };
3271
3272        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3273            session.peer.send(connection_id, update.clone())?;
3274        }
3275    };
3276
3277    send_notifications(&connection_pool, &session.peer, notifications);
3278
3279    response.send(proto::Ack {})?;
3280
3281    Ok(())
3282}
3283
3284/// Join the channels' room
3285async fn join_channel(
3286    request: proto::JoinChannel,
3287    response: Response<proto::JoinChannel>,
3288    session: Session,
3289) -> Result<()> {
3290    let channel_id = ChannelId::from_proto(request.channel_id);
3291    join_channel_internal(channel_id, Box::new(response), session).await
3292}
3293
/// A response handle that can deliver a `JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` response types, so that
/// `join_channel_internal` (which takes `Box<impl JoinChannelInternalResponse>`) can
/// reply through either one.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response::send` for the `JoinChannel` request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response::send` for the `JoinRoom` request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3307
/// Shared implementation for joining the room associated with `channel_id`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database. If a LiveKit client is
///    configured, mint credentials: guests get a guest token with
///    `can_publish = false`, everyone else a room token with
///    `can_publish = true`. Then reply through `response`.
/// 3. Apply any membership change to the connection pool, notify the user's
///    connections, and call `room_updated` for the joined room.
/// 4. Call `channel_updated` for the channel and refresh the user's contacts.
///
/// # Errors
/// Propagates database, response-send and contact-update errors. Also fails
/// if the joined room has no associated channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room
            // (presumably `leave_room_for_session` acquires it itself), then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the
        // token fails (the error is traced by `trace_err` and swallowed).
        // Guests may not publish media.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database handle and the
        // connection-pool guard.
        joined_room
    };

    // The connection pool is reacquired here just for `channel_updated`.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3401
3402/// Start editing the channel notes
3403async fn join_channel_buffer(
3404    request: proto::JoinChannelBuffer,
3405    response: Response<proto::JoinChannelBuffer>,
3406    session: Session,
3407) -> Result<()> {
3408    let db = session.db().await;
3409    let channel_id = ChannelId::from_proto(request.channel_id);
3410
3411    let open_response = db
3412        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3413        .await?;
3414
3415    let collaborators = open_response.collaborators.clone();
3416    response.send(open_response)?;
3417
3418    let update = UpdateChannelBufferCollaborators {
3419        channel_id: channel_id.to_proto(),
3420        collaborators: collaborators.clone(),
3421    };
3422    channel_buffer_updated(
3423        session.connection_id,
3424        collaborators
3425            .iter()
3426            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3427        &update,
3428        &session.peer,
3429    );
3430
3431    Ok(())
3432}
3433
3434/// Edit the channel notes
3435async fn update_channel_buffer(
3436    request: proto::UpdateChannelBuffer,
3437    session: Session,
3438) -> Result<()> {
3439    let db = session.db().await;
3440    let channel_id = ChannelId::from_proto(request.channel_id);
3441
3442    let (collaborators, epoch, version) = db
3443        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3444        .await?;
3445
3446    channel_buffer_updated(
3447        session.connection_id,
3448        collaborators.clone(),
3449        &proto::UpdateChannelBuffer {
3450            channel_id: channel_id.to_proto(),
3451            operations: request.operations,
3452        },
3453        &session.peer,
3454    );
3455
3456    let pool = &*session.connection_pool().await;
3457
3458    let non_collaborators =
3459        pool.channel_connection_ids(channel_id)
3460            .filter_map(|(connection_id, _)| {
3461                if collaborators.contains(&connection_id) {
3462                    None
3463                } else {
3464                    Some(connection_id)
3465                }
3466            });
3467
3468    broadcast(None, non_collaborators, |peer_id| {
3469        session.peer.send(
3470            peer_id,
3471            proto::UpdateChannels {
3472                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3473                    channel_id: channel_id.to_proto(),
3474                    epoch: epoch as u64,
3475                    version: version.clone(),
3476                }],
3477                ..Default::default()
3478            },
3479        )
3480    });
3481
3482    Ok(())
3483}
3484
3485/// Rejoin the channel notes after a connection blip
3486async fn rejoin_channel_buffers(
3487    request: proto::RejoinChannelBuffers,
3488    response: Response<proto::RejoinChannelBuffers>,
3489    session: Session,
3490) -> Result<()> {
3491    let db = session.db().await;
3492    let buffers = db
3493        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3494        .await?;
3495
3496    for rejoined_buffer in &buffers {
3497        let collaborators_to_notify = rejoined_buffer
3498            .buffer
3499            .collaborators
3500            .iter()
3501            .filter_map(|c| Some(c.peer_id?.into()));
3502        channel_buffer_updated(
3503            session.connection_id,
3504            collaborators_to_notify,
3505            &proto::UpdateChannelBufferCollaborators {
3506                channel_id: rejoined_buffer.buffer.channel_id,
3507                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3508            },
3509            &session.peer,
3510        );
3511    }
3512
3513    response.send(proto::RejoinChannelBuffersResponse {
3514        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3515    })?;
3516
3517    Ok(())
3518}
3519
3520/// Stop editing the channel notes
3521async fn leave_channel_buffer(
3522    request: proto::LeaveChannelBuffer,
3523    response: Response<proto::LeaveChannelBuffer>,
3524    session: Session,
3525) -> Result<()> {
3526    let db = session.db().await;
3527    let channel_id = ChannelId::from_proto(request.channel_id);
3528
3529    let left_buffer = db
3530        .leave_channel_buffer(channel_id, session.connection_id)
3531        .await?;
3532
3533    response.send(Ack {})?;
3534
3535    channel_buffer_updated(
3536        session.connection_id,
3537        left_buffer.connections,
3538        &proto::UpdateChannelBufferCollaborators {
3539            channel_id: channel_id.to_proto(),
3540            collaborators: left_buffer.collaborators,
3541        },
3542        &session.peer,
3543    );
3544
3545    Ok(())
3546}
3547
3548fn channel_buffer_updated<T: EnvelopedMessage>(
3549    sender_id: ConnectionId,
3550    collaborators: impl IntoIterator<Item = ConnectionId>,
3551    message: &T,
3552    peer: &Peer,
3553) {
3554    broadcast(Some(sender_id), collaborators, |peer_id| {
3555        peer.send(peer_id, message.clone())
3556    });
3557}
3558
3559fn send_notifications(
3560    connection_pool: &ConnectionPool,
3561    peer: &Peer,
3562    notifications: db::NotificationBatch,
3563) {
3564    for (user_id, notification) in notifications {
3565        for connection_id in connection_pool.user_connection_ids(user_id) {
3566            if let Err(error) = peer.send(
3567                connection_id,
3568                proto::AddNotification {
3569                    notification: Some(notification.clone()),
3570                },
3571            ) {
3572                tracing::error!(
3573                    "failed to send notification to {:?} {}",
3574                    connection_id,
3575                    error
3576                );
3577            }
3578        }
3579    }
3580}
3581
3582/// Send a message to the channel
3583async fn send_channel_message(
3584    request: proto::SendChannelMessage,
3585    response: Response<proto::SendChannelMessage>,
3586    session: Session,
3587) -> Result<()> {
3588    // Validate the message body.
3589    let body = request.body.trim().to_string();
3590    if body.len() > MAX_MESSAGE_LEN {
3591        return Err(anyhow!("message is too long"))?;
3592    }
3593    if body.is_empty() {
3594        return Err(anyhow!("message can't be blank"))?;
3595    }
3596
3597    // TODO: adjust mentions if body is trimmed
3598
3599    let timestamp = OffsetDateTime::now_utc();
3600    let nonce = request.nonce.context("nonce can't be blank")?;
3601
3602    let channel_id = ChannelId::from_proto(request.channel_id);
3603    let CreatedChannelMessage {
3604        message_id,
3605        participant_connection_ids,
3606        notifications,
3607    } = session
3608        .db()
3609        .await
3610        .create_channel_message(
3611            channel_id,
3612            session.user_id(),
3613            &body,
3614            &request.mentions,
3615            timestamp,
3616            nonce.clone().into(),
3617            request.reply_to_message_id.map(MessageId::from_proto),
3618        )
3619        .await?;
3620
3621    let message = proto::ChannelMessage {
3622        sender_id: session.user_id().to_proto(),
3623        id: message_id.to_proto(),
3624        body,
3625        mentions: request.mentions,
3626        timestamp: timestamp.unix_timestamp() as u64,
3627        nonce: Some(nonce),
3628        reply_to_message_id: request.reply_to_message_id,
3629        edited_at: None,
3630    };
3631    broadcast(
3632        Some(session.connection_id),
3633        participant_connection_ids.clone(),
3634        |connection| {
3635            session.peer.send(
3636                connection,
3637                proto::ChannelMessageSent {
3638                    channel_id: channel_id.to_proto(),
3639                    message: Some(message.clone()),
3640                },
3641            )
3642        },
3643    );
3644    response.send(proto::SendChannelMessageResponse {
3645        message: Some(message),
3646    })?;
3647
3648    let pool = &*session.connection_pool().await;
3649    let non_participants =
3650        pool.channel_connection_ids(channel_id)
3651            .filter_map(|(connection_id, _)| {
3652                if participant_connection_ids.contains(&connection_id) {
3653                    None
3654                } else {
3655                    Some(connection_id)
3656                }
3657            });
3658    broadcast(None, non_participants, |peer_id| {
3659        session.peer.send(
3660            peer_id,
3661            proto::UpdateChannels {
3662                latest_channel_message_ids: vec![proto::ChannelMessageId {
3663                    channel_id: channel_id.to_proto(),
3664                    message_id: message_id.to_proto(),
3665                }],
3666                ..Default::default()
3667            },
3668        )
3669    });
3670    send_notifications(pool, &session.peer, notifications);
3671
3672    Ok(())
3673}
3674
3675/// Delete a channel message
3676async fn remove_channel_message(
3677    request: proto::RemoveChannelMessage,
3678    response: Response<proto::RemoveChannelMessage>,
3679    session: Session,
3680) -> Result<()> {
3681    let channel_id = ChannelId::from_proto(request.channel_id);
3682    let message_id = MessageId::from_proto(request.message_id);
3683    let (connection_ids, existing_notification_ids) = session
3684        .db()
3685        .await
3686        .remove_channel_message(channel_id, message_id, session.user_id())
3687        .await?;
3688
3689    broadcast(
3690        Some(session.connection_id),
3691        connection_ids,
3692        move |connection| {
3693            session.peer.send(connection, request.clone())?;
3694
3695            for notification_id in &existing_notification_ids {
3696                session.peer.send(
3697                    connection,
3698                    proto::DeleteNotification {
3699                        notification_id: (*notification_id).to_proto(),
3700                    },
3701                )?;
3702            }
3703
3704            Ok(())
3705        },
3706    );
3707    response.send(proto::Ack {})?;
3708    Ok(())
3709}
3710
3711async fn update_channel_message(
3712    request: proto::UpdateChannelMessage,
3713    response: Response<proto::UpdateChannelMessage>,
3714    session: Session,
3715) -> Result<()> {
3716    let channel_id = ChannelId::from_proto(request.channel_id);
3717    let message_id = MessageId::from_proto(request.message_id);
3718    let updated_at = OffsetDateTime::now_utc();
3719    let UpdatedChannelMessage {
3720        message_id,
3721        participant_connection_ids,
3722        notifications,
3723        reply_to_message_id,
3724        timestamp,
3725        deleted_mention_notification_ids,
3726        updated_mention_notifications,
3727    } = session
3728        .db()
3729        .await
3730        .update_channel_message(
3731            channel_id,
3732            message_id,
3733            session.user_id(),
3734            request.body.as_str(),
3735            &request.mentions,
3736            updated_at,
3737        )
3738        .await?;
3739
3740    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3741
3742    let message = proto::ChannelMessage {
3743        sender_id: session.user_id().to_proto(),
3744        id: message_id.to_proto(),
3745        body: request.body.clone(),
3746        mentions: request.mentions.clone(),
3747        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3748        nonce: Some(nonce),
3749        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3750        edited_at: Some(updated_at.unix_timestamp() as u64),
3751    };
3752
3753    response.send(proto::Ack {})?;
3754
3755    let pool = &*session.connection_pool().await;
3756    broadcast(
3757        Some(session.connection_id),
3758        participant_connection_ids,
3759        |connection| {
3760            session.peer.send(
3761                connection,
3762                proto::ChannelMessageUpdate {
3763                    channel_id: channel_id.to_proto(),
3764                    message: Some(message.clone()),
3765                },
3766            )?;
3767
3768            for notification_id in &deleted_mention_notification_ids {
3769                session.peer.send(
3770                    connection,
3771                    proto::DeleteNotification {
3772                        notification_id: (*notification_id).to_proto(),
3773                    },
3774                )?;
3775            }
3776
3777            for notification in &updated_mention_notifications {
3778                session.peer.send(
3779                    connection,
3780                    proto::UpdateNotification {
3781                        notification: Some(notification.clone()),
3782                    },
3783                )?;
3784            }
3785
3786            Ok(())
3787        },
3788    );
3789
3790    send_notifications(pool, &session.peer, notifications);
3791
3792    Ok(())
3793}
3794
3795/// Mark a channel message as read
3796async fn acknowledge_channel_message(
3797    request: proto::AckChannelMessage,
3798    session: Session,
3799) -> Result<()> {
3800    let channel_id = ChannelId::from_proto(request.channel_id);
3801    let message_id = MessageId::from_proto(request.message_id);
3802    let notifications = session
3803        .db()
3804        .await
3805        .observe_channel_message(channel_id, session.user_id(), message_id)
3806        .await?;
3807    send_notifications(
3808        &*session.connection_pool().await,
3809        &session.peer,
3810        notifications,
3811    );
3812    Ok(())
3813}
3814
3815/// Mark a buffer version as synced
3816async fn acknowledge_buffer_version(
3817    request: proto::AckBufferOperation,
3818    session: Session,
3819) -> Result<()> {
3820    let buffer_id = BufferId::from_proto(request.buffer_id);
3821    session
3822        .db()
3823        .await
3824        .observe_buffer_version(
3825            buffer_id,
3826            session.user_id(),
3827            request.epoch as i32,
3828            &request.version,
3829        )
3830        .await?;
3831    Ok(())
3832}
3833
3834/// Get a Supermaven API key for the user
3835async fn get_supermaven_api_key(
3836    _request: proto::GetSupermavenApiKey,
3837    response: Response<proto::GetSupermavenApiKey>,
3838    session: Session,
3839) -> Result<()> {
3840    let user_id: String = session.user_id().to_string();
3841    if !session.is_staff() {
3842        return Err(anyhow!("supermaven not enabled for this account"))?;
3843    }
3844
3845    let email = session.email().context("user must have an email")?;
3846
3847    let supermaven_admin_api = session
3848        .supermaven_client
3849        .as_ref()
3850        .context("supermaven not configured")?;
3851
3852    let result = supermaven_admin_api
3853        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3854        .await?;
3855
3856    response.send(proto::GetSupermavenApiKeyResponse {
3857        api_key: result.api_key,
3858    })?;
3859
3860    Ok(())
3861}
3862
3863/// Start receiving chat updates for a channel
3864async fn join_channel_chat(
3865    request: proto::JoinChannelChat,
3866    response: Response<proto::JoinChannelChat>,
3867    session: Session,
3868) -> Result<()> {
3869    let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871    let db = session.db().await;
3872    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3873        .await?;
3874    let messages = db
3875        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3876        .await?;
3877    response.send(proto::JoinChannelChatResponse {
3878        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3879        messages,
3880    })?;
3881    Ok(())
3882}
3883
3884/// Stop receiving chat updates for a channel
3885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3886    let channel_id = ChannelId::from_proto(request.channel_id);
3887    session
3888        .db()
3889        .await
3890        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891        .await?;
3892    Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897    request: proto::GetChannelMessages,
3898    response: Response<proto::GetChannelMessages>,
3899    session: Session,
3900) -> Result<()> {
3901    let channel_id = ChannelId::from_proto(request.channel_id);
3902    let messages = session
3903        .db()
3904        .await
3905        .get_channel_messages(
3906            channel_id,
3907            session.user_id(),
3908            MESSAGE_COUNT_PER_PAGE,
3909            Some(MessageId::from_proto(request.before_message_id)),
3910        )
3911        .await?;
3912    response.send(proto::GetChannelMessagesResponse {
3913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914        messages,
3915    })?;
3916    Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921    request: proto::GetChannelMessagesById,
3922    response: Response<proto::GetChannelMessagesById>,
3923    session: Session,
3924) -> Result<()> {
3925    let message_ids = request
3926        .message_ids
3927        .iter()
3928        .map(|id| MessageId::from_proto(*id))
3929        .collect::<Vec<_>>();
3930    let messages = session
3931        .db()
3932        .await
3933        .get_channel_messages_by_id(session.user_id(), &message_ids)
3934        .await?;
3935    response.send(proto::GetChannelMessagesResponse {
3936        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937        messages,
3938    })?;
3939    Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944    request: proto::GetNotifications,
3945    response: Response<proto::GetNotifications>,
3946    session: Session,
3947) -> Result<()> {
3948    let notifications = session
3949        .db()
3950        .await
3951        .get_notifications(
3952            session.user_id(),
3953            NOTIFICATION_COUNT_PER_PAGE,
3954            request.before_id.map(db::NotificationId::from_proto),
3955        )
3956        .await?;
3957    response.send(proto::GetNotificationsResponse {
3958        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959        notifications,
3960    })?;
3961    Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966    request: proto::MarkNotificationRead,
3967    response: Response<proto::MarkNotificationRead>,
3968    session: Session,
3969) -> Result<()> {
3970    let database = &session.db().await;
3971    let notifications = database
3972        .mark_notification_as_read_by_id(
3973            session.user_id(),
3974            NotificationId::from_proto(request.notification_id),
3975        )
3976        .await?;
3977    send_notifications(
3978        &*session.connection_pool().await,
3979        &session.peer,
3980        notifications,
3981    );
3982    response.send(proto::Ack {})?;
3983    Ok(())
3984}
3985
3986/// Get the current users information
3987async fn get_private_user_info(
3988    _request: proto::GetPrivateUserInfo,
3989    response: Response<proto::GetPrivateUserInfo>,
3990    session: Session,
3991) -> Result<()> {
3992    let db = session.db().await;
3993
3994    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3995    let user = db
3996        .get_user_by_id(session.user_id())
3997        .await?
3998        .context("user not found")?;
3999    let flags = db.get_user_flags(session.user_id()).await?;
4000
4001    response.send(proto::GetPrivateUserInfoResponse {
4002        metrics_id,
4003        staff: user.admin,
4004        flags,
4005        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4006    })?;
4007    Ok(())
4008}
4009
4010/// Accept the terms of service (tos) on behalf of the current user
4011async fn accept_terms_of_service(
4012    _request: proto::AcceptTermsOfService,
4013    response: Response<proto::AcceptTermsOfService>,
4014    session: Session,
4015) -> Result<()> {
4016    let db = session.db().await;
4017
4018    let accepted_tos_at = Utc::now();
4019    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4020        .await?;
4021
4022    response.send(proto::AcceptTermsOfServiceResponse {
4023        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4024    })?;
4025    Ok(())
4026}
4027
/// The minimum account age an account must have in order to use the LLM service
/// (30 days).
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4030
/// Issue an LLM API token for the current user.
///
/// Requires that the user has accepted the terms of service. Then:
/// 1. Uses the user's billing customer from the database. If there is none, it
///    calls `find_or_create_customer_by_email` on Stripe billing and
///    `find_or_create_billing_customer` to get one.
/// 2. Uses the user's active billing subscription. If there is none, it
///    subscribes the customer to Zed Free in Stripe and records that
///    subscription in the database.
/// 3. Builds `LlmTokenClaims` from the user, staff status, billing customer,
///    billing preferences, feature flags, subscription, system id and app
///    config, and replies with the resulting token.
///
/// # Errors
/// Fails if the user is missing, has not accepted the terms of service, the
/// Stripe client or billing object is not configured, or any database, Stripe
/// or token-creation step fails.
///
/// NOTE(review): the "look up, else create" steps for the customer and the
/// subscription have no visible guard against concurrent requests for the same
/// user. Confirm whether the database or Stripe layer deduplicates them.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    // Both Stripe handles are optional in the app state. Without them no
    // billing customer or subscription can be ensured, so the request fails.
    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Prefer the customer already recorded in the database. Otherwise resolve
    // a Stripe customer id by email and get the matching billing customer.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are subscribed to Zed Free in
    // Stripe first; the resulting Stripe subscription is then stored.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4114
4115fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4116    let message = match message {
4117        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4118        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4119        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4120        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4121        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4122            code: frame.code.into(),
4123            reason: frame.reason.as_str().to_owned().into(),
4124        })),
4125        // We should never receive a frame while reading the message, according
4126        // to the `tungstenite` maintainers:
4127        //
4128        // > It cannot occur when you read messages from the WebSocket, but it
4129        // > can be used when you want to send the raw frames (e.g. you want to
4130        // > send the frames to the WebSocket without composing the full message first).
4131        // >
4132        // > — https://github.com/snapview/tungstenite-rs/issues/268
4133        TungsteniteMessage::Frame(_) => {
4134            bail!("received an unexpected frame while reading the message")
4135        }
4136    };
4137
4138    Ok(message)
4139}
4140
4141fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4142    match message {
4143        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4144        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4145        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4146        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4147        AxumMessage::Close(frame) => {
4148            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4149                code: frame.code.into(),
4150                reason: frame.reason.as_ref().into(),
4151            }))
4152        }
4153    }
4154}
4155
4156fn notify_membership_updated(
4157    connection_pool: &mut ConnectionPool,
4158    result: MembershipUpdated,
4159    user_id: UserId,
4160    peer: &Peer,
4161) {
4162    for membership in &result.new_channels.channel_memberships {
4163        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4164    }
4165    for channel_id in &result.removed_channels {
4166        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4167    }
4168
4169    let user_channels_update = proto::UpdateUserChannels {
4170        channel_memberships: result
4171            .new_channels
4172            .channel_memberships
4173            .iter()
4174            .map(|cm| proto::ChannelMembership {
4175                channel_id: cm.channel_id.to_proto(),
4176                role: cm.role.into(),
4177            })
4178            .collect(),
4179        ..Default::default()
4180    };
4181
4182    let mut update = build_channels_update(result.new_channels);
4183    update.delete_channels = result
4184        .removed_channels
4185        .into_iter()
4186        .map(|id| id.to_proto())
4187        .collect();
4188    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4189
4190    for connection_id in connection_pool.user_connection_ids(user_id) {
4191        peer.send(connection_id, user_channels_update.clone())
4192            .trace_err();
4193        peer.send(connection_id, update.clone()).trace_err();
4194    }
4195}
4196
4197fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4198    proto::UpdateUserChannels {
4199        channel_memberships: channels
4200            .channel_memberships
4201            .iter()
4202            .map(|m| proto::ChannelMembership {
4203                channel_id: m.channel_id.to_proto(),
4204                role: m.role.into(),
4205            })
4206            .collect(),
4207        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4208        observed_channel_message_id: channels.observed_channel_messages.clone(),
4209    }
4210}
4211
4212fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4213    let mut update = proto::UpdateChannels::default();
4214
4215    for channel in channels.channels {
4216        update.channels.push(channel.to_proto());
4217    }
4218
4219    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4220    update.latest_channel_message_ids = channels.latest_channel_messages;
4221
4222    for (channel_id, participants) in channels.channel_participants {
4223        update
4224            .channel_participants
4225            .push(proto::ChannelParticipants {
4226                channel_id: channel_id.to_proto(),
4227                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4228            });
4229    }
4230
4231    for channel in channels.invited_channels {
4232        update.channel_invitations.push(channel.to_proto());
4233    }
4234
4235    update
4236}
4237
4238fn build_initial_contacts_update(
4239    contacts: Vec<db::Contact>,
4240    pool: &ConnectionPool,
4241) -> proto::UpdateContacts {
4242    let mut update = proto::UpdateContacts::default();
4243
4244    for contact in contacts {
4245        match contact {
4246            db::Contact::Accepted { user_id, busy } => {
4247                update.contacts.push(contact_for_user(user_id, busy, pool));
4248            }
4249            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4250            db::Contact::Incoming { user_id } => {
4251                update
4252                    .incoming_requests
4253                    .push(proto::IncomingContactRequest {
4254                        requester_id: user_id.to_proto(),
4255                    })
4256            }
4257        }
4258    }
4259
4260    update
4261}
4262
4263fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4264    proto::Contact {
4265        user_id: user_id.to_proto(),
4266        online: pool.is_user_online(user_id),
4267        busy,
4268    }
4269}
4270
4271fn room_updated(room: &proto::Room, peer: &Peer) {
4272    broadcast(
4273        None,
4274        room.participants
4275            .iter()
4276            .filter_map(|participant| Some(participant.peer_id?.into())),
4277        |peer_id| {
4278            peer.send(
4279                peer_id,
4280                proto::RoomUpdated {
4281                    room: Some(room.clone()),
4282                },
4283            )
4284        },
4285    );
4286}
4287
4288fn channel_updated(
4289    channel: &db::channel::Model,
4290    room: &proto::Room,
4291    peer: &Peer,
4292    pool: &ConnectionPool,
4293) {
4294    let participants = room
4295        .participants
4296        .iter()
4297        .map(|p| p.user_id)
4298        .collect::<Vec<_>>();
4299
4300    broadcast(
4301        None,
4302        pool.channel_connection_ids(channel.root_id())
4303            .filter_map(|(channel_id, role)| {
4304                role.can_see_channel(channel.visibility)
4305                    .then_some(channel_id)
4306            }),
4307        |peer_id| {
4308            peer.send(
4309                peer_id,
4310                proto::UpdateChannels {
4311                    channel_participants: vec![proto::ChannelParticipants {
4312                        channel_id: channel.id.to_proto(),
4313                        participant_user_ids: participants.clone(),
4314                    }],
4315                    ..Default::default()
4316                },
4317            )
4318        },
4319    );
4320}
4321
4322async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4323    let db = session.db().await;
4324
4325    let contacts = db.get_contacts(user_id).await?;
4326    let busy = db.is_user_busy(user_id).await?;
4327
4328    let pool = session.connection_pool().await;
4329    let updated_contact = contact_for_user(user_id, busy, &pool);
4330    for contact in contacts {
4331        if let db::Contact::Accepted {
4332            user_id: contact_user_id,
4333            ..
4334        } = contact
4335        {
4336            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4337                session
4338                    .peer
4339                    .send(
4340                        contact_conn_id,
4341                        proto::UpdateContacts {
4342                            contacts: vec![updated_contact.clone()],
4343                            remove_contacts: Default::default(),
4344                            incoming_requests: Default::default(),
4345                            remove_incoming_requests: Default::default(),
4346                            outgoing_requests: Default::default(),
4347                            remove_outgoing_requests: Default::default(),
4348                        },
4349                    )
4350                    .trace_err();
4351            }
4352        }
4353    }
4354    Ok(())
4355}
4356
4357async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4358    let mut contacts_to_update = HashSet::default();
4359
4360    let room_id;
4361    let canceled_calls_to_user_ids;
4362    let livekit_room;
4363    let delete_livekit_room;
4364    let room;
4365    let channel;
4366
4367    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4368        contacts_to_update.insert(session.user_id());
4369
4370        for project in left_room.left_projects.values() {
4371            project_left(project, session);
4372        }
4373
4374        room_id = RoomId::from_proto(left_room.room.id);
4375        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4376        livekit_room = mem::take(&mut left_room.room.livekit_room);
4377        delete_livekit_room = left_room.deleted;
4378        room = mem::take(&mut left_room.room);
4379        channel = mem::take(&mut left_room.channel);
4380
4381        room_updated(&room, &session.peer);
4382    } else {
4383        return Ok(());
4384    }
4385
4386    if let Some(channel) = channel {
4387        channel_updated(
4388            &channel,
4389            &room,
4390            &session.peer,
4391            &*session.connection_pool().await,
4392        );
4393    }
4394
4395    {
4396        let pool = session.connection_pool().await;
4397        for canceled_user_id in canceled_calls_to_user_ids {
4398            for connection_id in pool.user_connection_ids(canceled_user_id) {
4399                session
4400                    .peer
4401                    .send(
4402                        connection_id,
4403                        proto::CallCanceled {
4404                            room_id: room_id.to_proto(),
4405                        },
4406                    )
4407                    .trace_err();
4408            }
4409            contacts_to_update.insert(canceled_user_id);
4410        }
4411    }
4412
4413    for contact_user_id in contacts_to_update {
4414        update_user_contacts(contact_user_id, session).await?;
4415    }
4416
4417    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4418        live_kit
4419            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4420            .await
4421            .trace_err();
4422
4423        if delete_livekit_room {
4424            live_kit.delete_room(livekit_room).await.trace_err();
4425        }
4426    }
4427
4428    Ok(())
4429}
4430
4431async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4432    let left_channel_buffers = session
4433        .db()
4434        .await
4435        .leave_channel_buffers(session.connection_id)
4436        .await?;
4437
4438    for left_buffer in left_channel_buffers {
4439        channel_buffer_updated(
4440            session.connection_id,
4441            left_buffer.connections,
4442            &proto::UpdateChannelBufferCollaborators {
4443                channel_id: left_buffer.channel_id.to_proto(),
4444                collaborators: left_buffer.collaborators,
4445            },
4446            &session.peer,
4447        );
4448    }
4449
4450    Ok(())
4451}
4452
4453fn project_left(project: &db::LeftProject, session: &Session) {
4454    for connection_id in &project.connection_ids {
4455        if project.should_unshare {
4456            session
4457                .peer
4458                .send(
4459                    *connection_id,
4460                    proto::UnshareProject {
4461                        project_id: project.id.to_proto(),
4462                    },
4463                )
4464                .trace_err();
4465        } else {
4466            session
4467                .peer
4468                .send(
4469                    *connection_id,
4470                    proto::RemoveProjectCollaborator {
4471                        project_id: project.id.to_proto(),
4472                        peer_id: Some(session.connection_id.into()),
4473                    },
4474                )
4475                .trace_err();
4476        }
4477    }
4478}
4479
/// Extension for result-like values that reports a failure instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the wrapped value.
    type Ok;

    /// Converts `self` into an `Option`.
    ///
    /// Returns `Some` with the success value. On failure, reports the error and
    /// returns `None`. The `Result` implementation in this file reports via
    /// `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4485
4486impl<T, E> ResultExt for Result<T, E>
4487where
4488    E: std::fmt::Debug,
4489{
4490    type Ok = T;
4491
4492    #[track_caller]
4493    fn trace_err(self) -> Option<T> {
4494        match self {
4495            Ok(value) => Some(value),
4496            Err(error) => {
4497                tracing::error!("{:?}", error);
4498                None
4499            }
4500        }
4501    }
4502}