rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   6use crate::{
   7    AppState, Error, Result, auth,
   8    db::{
   9        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  10        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  11        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  12        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15};
  16use anyhow::{Context as _, anyhow, bail};
  17use async_tungstenite::tungstenite::{
  18    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  19};
  20use axum::{
  21    Extension, Router, TypedHeader,
  22    body::Body,
  23    extract::{
  24        ConnectInfo, WebSocketUpgrade,
  25        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{MutexGuard, Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
/// Reconnection timeout of 30 seconds.
// NOTE(review): not referenced in this part of the file; presumably the window a dropped
// client has to reconnect — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up stale rooms, channel buffers and servers.
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file; their names
// suggest paging sizes and a message length cap — confirm against the handlers that use them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  86
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed, untyped envelope together with the [`Session`] it arrived on and
/// returns a boxed future that resolves once handling is finished.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone, Debug)]
 105pub enum Principal {
 106    User(User),
 107    Impersonated { user: User, admin: User },
 108}
 109
 110impl Principal {
 111    fn update_span(&self, span: &tracing::Span) {
 112        match &self {
 113            Principal::User(user) => {
 114                span.record("user_id", user.id.0);
 115                span.record("login", &user.github_login);
 116            }
 117            Principal::Impersonated { user, admin } => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120                span.record("impersonator", &admin.github_login);
 121            }
 122        }
 123    }
 124}
 125
/// Per-connection state handed to every message handler.
///
/// The struct is `Clone`; the database handle, peer, connection pool and app state are
/// behind `Arc`s, so clones share them.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or an admin impersonating a user).
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the server; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Optional system identifier.
    // NOTE(review): presumably taken from the `SystemIdHeader` request header — confirm.
    system_id: Option<String>,
    _executor: Executor,
}
 141
 142impl Session {
 143    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 144        #[cfg(test)]
 145        tokio::task::yield_now().await;
 146        let guard = self.db.lock().await;
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        guard
 150    }
 151
 152    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.connection_pool.lock();
 156        ConnectionPoolGuard {
 157            guard,
 158            _not_send: PhantomData,
 159        }
 160    }
 161
 162    fn is_staff(&self) -> bool {
 163        match &self.principal {
 164            Principal::User(user) => user.admin,
 165            Principal::Impersonated { .. } => true,
 166        }
 167    }
 168
 169    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 170        if self.is_staff() {
 171            return Ok(proto::Plan::ZedPro);
 172        }
 173
 174        let user_id = self.user_id();
 175
 176        let subscription = db.get_active_billing_subscription(user_id).await?;
 177        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
 178
 179        let plan = if let Some(subscription_kind) = subscription_kind {
 180            match subscription_kind {
 181                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
 182                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
 183                SubscriptionKind::ZedFree => proto::Plan::Free,
 184            }
 185        } else {
 186            proto::Plan::Free
 187        };
 188
 189        Ok(plan)
 190    }
 191
 192    fn user_id(&self) -> UserId {
 193        match &self.principal {
 194            Principal::User(user) => user.id,
 195            Principal::Impersonated { user, .. } => user.id,
 196        }
 197    }
 198
 199    pub fn email(&self) -> Option<String> {
 200        match &self.principal {
 201            Principal::User(user) => user.email_address.clone(),
 202            Principal::Impersonated { user, .. } => user.email_address.clone(),
 203        }
 204    }
 205}
 206
 207impl Debug for Session {
 208    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 209        let mut result = f.debug_struct("Session");
 210        match &self.principal {
 211            Principal::User(user) => {
 212                result.field("user", &user.github_login);
 213            }
 214            Principal::Impersonated { user, admin } => {
 215                result.field("user", &user.github_login);
 216                result.field("impersonator", &admin.github_login);
 217            }
 218        }
 219        result.field("connection_id", &self.connection_id).finish()
 220    }
 221}
 222
 223struct DbHandle(Arc<Database>);
 224
 225impl Deref for DbHandle {
 226    type Target = Database;
 227
 228    fn deref(&self) -> &Self::Target {
 229        self.0.as_ref()
 230    }
 231}
 232
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Pool of live connections.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel that `teardown` sets to `true` and the test-only `reset` sets back to `false`.
    teardown: watch::Sender<bool>,
}
 241
/// Guard over the locked [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker keeps the guard type `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 246
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// The locked connection pool; serialized through its `Deref` target via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 253
 254pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 255where
 256    S: Serializer,
 257    T: Deref<Target = U>,
 258    U: Serialize,
 259{
 260    Serialize::serialize(value.deref(), serializer)
 261}
 262
 263impl Server {
 264    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 265        let mut server = Self {
 266            id: parking_lot::Mutex::new(id),
 267            peer: Peer::new(id.0 as u32),
 268            app_state: app_state.clone(),
 269            connection_pool: Default::default(),
 270            handlers: Default::default(),
 271            teardown: watch::channel(false).0,
 272        };
 273
 274        server
 275            .add_request_handler(ping)
 276            .add_request_handler(create_room)
 277            .add_request_handler(join_room)
 278            .add_request_handler(rejoin_room)
 279            .add_request_handler(leave_room)
 280            .add_request_handler(set_room_participant_role)
 281            .add_request_handler(call)
 282            .add_request_handler(cancel_call)
 283            .add_message_handler(decline_call)
 284            .add_request_handler(update_participant_location)
 285            .add_request_handler(share_project)
 286            .add_message_handler(unshare_project)
 287            .add_request_handler(join_project)
 288            .add_message_handler(leave_project)
 289            .add_request_handler(update_project)
 290            .add_request_handler(update_worktree)
 291            .add_request_handler(update_repository)
 292            .add_request_handler(remove_repository)
 293            .add_message_handler(start_language_server)
 294            .add_message_handler(update_language_server)
 295            .add_message_handler(update_diagnostic_summary)
 296            .add_message_handler(update_worktree_settings)
 297            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 301            .add_request_handler(forward_find_search_candidates_request)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 315            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 316            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 317            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 318            .add_request_handler(
 319                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 320            )
 321            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 322            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 323            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 325            .add_request_handler(
 326                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 327            )
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 341            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 346            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 351            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 352            .add_request_handler(
 353                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 354            )
 355            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 356            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 357            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 358            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_request_handler(update_buffer)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 368            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 369            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 370            .add_request_handler(get_users)
 371            .add_request_handler(fuzzy_search_users)
 372            .add_request_handler(request_contact)
 373            .add_request_handler(remove_contact)
 374            .add_request_handler(respond_to_contact_request)
 375            .add_message_handler(subscribe_to_channels)
 376            .add_request_handler(create_channel)
 377            .add_request_handler(delete_channel)
 378            .add_request_handler(invite_channel_member)
 379            .add_request_handler(remove_channel_member)
 380            .add_request_handler(set_channel_member_role)
 381            .add_request_handler(set_channel_visibility)
 382            .add_request_handler(rename_channel)
 383            .add_request_handler(join_channel_buffer)
 384            .add_request_handler(leave_channel_buffer)
 385            .add_message_handler(update_channel_buffer)
 386            .add_request_handler(rejoin_channel_buffers)
 387            .add_request_handler(get_channel_members)
 388            .add_request_handler(respond_to_channel_invite)
 389            .add_request_handler(join_channel)
 390            .add_request_handler(join_channel_chat)
 391            .add_message_handler(leave_channel_chat)
 392            .add_request_handler(send_channel_message)
 393            .add_request_handler(remove_channel_message)
 394            .add_request_handler(update_channel_message)
 395            .add_request_handler(get_channel_messages)
 396            .add_request_handler(get_channel_messages_by_id)
 397            .add_request_handler(get_notifications)
 398            .add_request_handler(mark_notification_as_read)
 399            .add_request_handler(move_channel)
 400            .add_request_handler(follow)
 401            .add_message_handler(unfollow)
 402            .add_message_handler(update_followers)
 403            .add_request_handler(get_private_user_info)
 404            .add_request_handler(get_llm_api_token)
 405            .add_request_handler(accept_terms_of_service)
 406            .add_message_handler(acknowledge_channel_message)
 407            .add_message_handler(acknowledge_buffer_version)
 408            .add_request_handler(get_supermaven_api_key)
 409            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 410            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 411            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 413            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 415            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 416            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 417            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 418            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 419            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 420            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 421            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 422            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 424            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 425            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 426            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 427            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 428            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 429            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 430            .add_message_handler(update_context);
 431
 432        Arc::new(server)
 433    }
 434
    /// Schedules the background cleanup of stale server resources and returns immediately.
    ///
    /// A detached task is spawned on the app executor. After [`CLEANUP_TIMEOUT`] has elapsed it
    /// asks the database for the stale room and channel ids for this environment and server id,
    /// then:
    ///
    /// 1. clears stale collaborators from each channel buffer and notifies the remaining
    ///    connections,
    /// 2. clears stale participants from each room, notifies about room/channel updates,
    ///    canceled calls and changed contacts, and deletes the LiveKit room once no participants
    ///    are left,
    /// 3. deletes stale server records.
    ///
    /// Every fallible step goes through `trace_err()`, so a failure only skips that step
    /// (presumably after logging it — confirm in `trace_err`) and never aborts the task.
    ///
    /// This method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Capture everything the detached task needs, since it cannot borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection the database returned.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then notify everyone affected.
                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify,
                        // nothing to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both the stale participants and the users whose incoming calls
                            // were canceled need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is only deleted when nobody is left in the room.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is taken in its own scope so that it is released
                        // before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the updated contact entry of each affected user to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip the user unless both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 580
 581    pub fn teardown(&self) {
 582        self.peer.teardown();
 583        self.connection_pool.lock().reset();
 584        let _ = self.teardown.send(true);
 585    }
 586
 587    #[cfg(test)]
 588    pub fn reset(&self, id: ServerId) {
 589        self.teardown();
 590        *self.id.lock() = id;
 591        self.peer.reset(id.0 as u32);
 592        let _ = self.teardown.send(false);
 593    }
 594
 595    #[cfg(test)]
 596    pub fn id(&self) -> ServerId {
 597        *self.id.lock()
 598    }
 599
 600    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 601    where
 602        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 603        Fut: 'static + Send + Future<Output = Result<()>>,
 604        M: EnvelopedMessage,
 605    {
 606        let prev_handler = self.handlers.insert(
 607            TypeId::of::<M>(),
 608            Box::new(move |envelope, session| {
 609                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 610                let received_at = envelope.received_at;
 611                tracing::info!("message received");
 612                let start_time = Instant::now();
 613                let future = (handler)(*envelope, session);
 614                async move {
 615                    let result = future.await;
 616                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 617                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 618                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 619                    let payload_type = M::NAME;
 620
 621                    match result {
 622                        Err(error) => {
 623                            tracing::error!(
 624                                ?error,
 625                                total_duration_ms,
 626                                processing_duration_ms,
 627                                queue_duration_ms,
 628                                payload_type,
 629                                "error handling message"
 630                            )
 631                        }
 632                        Ok(()) => tracing::info!(
 633                            total_duration_ms,
 634                            processing_duration_ms,
 635                            queue_duration_ms,
 636                            "finished handling message"
 637                        ),
 638                    }
 639                }
 640                .boxed()
 641            }),
 642        );
 643        if prev_handler.is_some() {
 644            panic!("registered a handler for the same message twice");
 645        }
 646        self
 647    }
 648
 649    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 650    where
 651        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 652        Fut: 'static + Send + Future<Output = Result<()>>,
 653        M: EnvelopedMessage,
 654    {
 655        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 656        self
 657    }
 658
 659    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 660    where
 661        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 662        Fut: Send + Future<Output = Result<()>>,
 663        M: RequestMessage,
 664    {
 665        let handler = Arc::new(handler);
 666        self.add_handler(move |envelope, session| {
 667            let receipt = envelope.receipt();
 668            let handler = handler.clone();
 669            async move {
 670                let peer = session.peer.clone();
 671                let responded = Arc::new(AtomicBool::default());
 672                let response = Response {
 673                    peer: peer.clone(),
 674                    responded: responded.clone(),
 675                    receipt,
 676                };
 677                match (handler)(envelope.payload, response, session).await {
 678                    Ok(()) => {
 679                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 680                            Ok(())
 681                        } else {
 682                            Err(anyhow!("handler did not send a response"))?
 683                        }
 684                    }
 685                    Err(error) => {
 686                        let proto_err = match &error {
 687                            Error::Internal(err) => err.to_proto(),
 688                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 689                        };
 690                        peer.respond_with_error(receipt, proto_err)?;
 691                        Err(error)
 692                    }
 693                }
 694            }
 695        })
 696    }
 697
 698    pub fn handle_connection(
 699        self: &Arc<Self>,
 700        connection: Connection,
 701        address: String,
 702        principal: Principal,
 703        zed_version: ZedVersion,
 704        geoip_country_code: Option<String>,
 705        system_id: Option<String>,
 706        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 707        executor: Executor,
 708    ) -> impl Future<Output = ()> + use<> {
 709        let this = self.clone();
 710        let span = info_span!("handle connection", %address,
 711            connection_id=field::Empty,
 712            user_id=field::Empty,
 713            login=field::Empty,
 714            impersonator=field::Empty,
 715            geoip_country_code=field::Empty
 716        );
 717        principal.update_span(&span);
 718        if let Some(country_code) = geoip_country_code.as_ref() {
 719            span.record("geoip_country_code", country_code);
 720        }
 721
 722        let mut teardown = self.teardown.subscribe();
 723        async move {
 724            if *teardown.borrow() {
 725                tracing::error!("server is tearing down");
 726                return
 727            }
 728            let (connection_id, handle_io, mut incoming_rx) = this
 729                .peer
 730                .add_connection(connection, {
 731                    let executor = executor.clone();
 732                    move |duration| executor.sleep(duration)
 733                });
 734            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 735
 736            tracing::info!("connection opened");
 737
 738            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 739            let http_client = match ReqwestClient::user_agent(&user_agent) {
 740                Ok(http_client) => Arc::new(http_client),
 741                Err(error) => {
 742                    tracing::error!(?error, "failed to create HTTP client");
 743                    return;
 744                }
 745            };
 746
 747            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 748                    supermaven_admin_api_key.to_string(),
 749                    http_client.clone(),
 750                )));
 751
 752            let session = Session {
 753                principal: principal.clone(),
 754                connection_id,
 755                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 756                peer: this.peer.clone(),
 757                connection_pool: this.connection_pool.clone(),
 758                app_state: this.app_state.clone(),
 759                geoip_country_code,
 760                system_id,
 761                _executor: executor.clone(),
 762                supermaven_client,
 763            };
 764
 765            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 766                tracing::error!(?error, "failed to send initial client update");
 767                return;
 768            }
 769
 770            let handle_io = handle_io.fuse();
 771            futures::pin_mut!(handle_io);
 772
 773            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 774            // This prevents deadlocks when e.g., client A performs a request to client B and
 775            // client B performs a request to client A. If both clients stop processing further
 776            // messages until their respective request completes, they won't have a chance to
 777            // respond to the other client's request and cause a deadlock.
 778            //
 779            // This arrangement ensures we will attempt to process earlier messages first, but fall
 780            // back to processing messages arrived later in the spirit of making progress.
 781            let mut foreground_message_handlers = FuturesUnordered::new();
 782            let concurrent_handlers = Arc::new(Semaphore::new(256));
 783            loop {
 784                let next_message = async {
 785                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 786                    let message = incoming_rx.next().await;
 787                    (permit, message)
 788                }.fuse();
 789                futures::pin_mut!(next_message);
 790                futures::select_biased! {
 791                    _ = teardown.changed().fuse() => return,
 792                    result = handle_io => {
 793                        if let Err(error) = result {
 794                            tracing::error!(?error, "error handling I/O");
 795                        }
 796                        break;
 797                    }
 798                    _ = foreground_message_handlers.next() => {}
 799                    next_message = next_message => {
 800                        let (permit, message) = next_message;
 801                        if let Some(message) = message {
 802                            let type_name = message.payload_type_name();
 803                            // note: we copy all the fields from the parent span so we can query them in the logs.
 804                            // (https://github.com/tokio-rs/tracing/issues/2670).
 805                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 806                                user_id=field::Empty,
 807                                login=field::Empty,
 808                                impersonator=field::Empty,
 809                            );
 810                            principal.update_span(&span);
 811                            let span_enter = span.enter();
 812                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 813                                let is_background = message.is_background();
 814                                let handle_message = (handler)(message, session.clone());
 815                                drop(span_enter);
 816
 817                                let handle_message = async move {
 818                                    handle_message.await;
 819                                    drop(permit);
 820                                }.instrument(span);
 821                                if is_background {
 822                                    executor.spawn_detached(handle_message);
 823                                } else {
 824                                    foreground_message_handlers.push(handle_message);
 825                                }
 826                            } else {
 827                                tracing::error!("no message handler");
 828                            }
 829                        } else {
 830                            tracing::info!("connection closed");
 831                            break;
 832                        }
 833                    }
 834                }
 835            }
 836
 837            drop(foreground_message_handlers);
 838            tracing::info!("signing out");
 839            if let Err(error) = connection_lost(session, teardown, executor).await {
 840                tracing::error!(?error, "error signing out");
 841            }
 842
 843        }.instrument(span)
 844    }
 845
 846    async fn send_initial_client_update(
 847        &self,
 848        connection_id: ConnectionId,
 849        principal: &Principal,
 850        zed_version: ZedVersion,
 851        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 852        session: &Session,
 853    ) -> Result<()> {
 854        self.peer.send(
 855            connection_id,
 856            proto::Hello {
 857                peer_id: Some(connection_id.into()),
 858            },
 859        )?;
 860        tracing::info!("sent hello message");
 861        if let Some(send_connection_id) = send_connection_id.take() {
 862            let _ = send_connection_id.send(connection_id);
 863        }
 864
 865        match principal {
 866            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 867                if !user.connected_once {
 868                    self.peer.send(connection_id, proto::ShowContacts {})?;
 869                    self.app_state
 870                        .db
 871                        .set_user_connected_once(user.id, true)
 872                        .await?;
 873                }
 874
 875                update_user_plan(user.id, session).await?;
 876
 877                let contacts = self.app_state.db.get_contacts(user.id).await?;
 878
 879                {
 880                    let mut pool = self.connection_pool.lock();
 881                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 882                    self.peer.send(
 883                        connection_id,
 884                        build_initial_contacts_update(contacts, &pool),
 885                    )?;
 886                }
 887
 888                if should_auto_subscribe_to_channels(zed_version) {
 889                    subscribe_user_to_channels(user.id, session).await?;
 890                }
 891
 892                if let Some(incoming_call) =
 893                    self.app_state.db.incoming_call_for_user(user.id).await?
 894                {
 895                    self.peer.send(connection_id, incoming_call)?;
 896                }
 897
 898                update_user_contacts(user.id, session).await?;
 899            }
 900        }
 901
 902        Ok(())
 903    }
 904
 905    pub async fn invite_code_redeemed(
 906        self: &Arc<Self>,
 907        inviter_id: UserId,
 908        invitee_id: UserId,
 909    ) -> Result<()> {
 910        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 911            if let Some(code) = &user.invite_code {
 912                let pool = self.connection_pool.lock();
 913                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 914                for connection_id in pool.user_connection_ids(inviter_id) {
 915                    self.peer.send(
 916                        connection_id,
 917                        proto::UpdateContacts {
 918                            contacts: vec![invitee_contact.clone()],
 919                            ..Default::default()
 920                        },
 921                    )?;
 922                    self.peer.send(
 923                        connection_id,
 924                        proto::UpdateInviteInfo {
 925                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 926                            count: user.invite_count as u32,
 927                        },
 928                    )?;
 929                }
 930            }
 931        }
 932        Ok(())
 933    }
 934
 935    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 936        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 937            if let Some(invite_code) = &user.invite_code {
 938                let pool = self.connection_pool.lock();
 939                for connection_id in pool.user_connection_ids(user_id) {
 940                    self.peer.send(
 941                        connection_id,
 942                        proto::UpdateInviteInfo {
 943                            url: format!(
 944                                "{}{}",
 945                                self.app_state.config.invite_link_prefix, invite_code
 946                            ),
 947                            count: user.invite_count as u32,
 948                        },
 949                    )?;
 950                }
 951            }
 952        }
 953        Ok(())
 954    }
 955
 956    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 957        let pool = self.connection_pool.lock();
 958        for connection_id in pool.user_connection_ids(user_id) {
 959            self.peer
 960                .send(connection_id, proto::RefreshLlmToken {})
 961                .trace_err();
 962        }
 963    }
 964
 965    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 966        ServerSnapshot {
 967            connection_pool: ConnectionPoolGuard {
 968                guard: self.connection_pool.lock(),
 969                _not_send: PhantomData,
 970            },
 971            peer: &self.peer,
 972        }
 973    }
 974}
 975
 976impl Deref for ConnectionPoolGuard<'_> {
 977    type Target = ConnectionPool;
 978
 979    fn deref(&self) -> &Self::Target {
 980        &self.guard
 981    }
 982}
 983
 984impl DerefMut for ConnectionPoolGuard<'_> {
 985    fn deref_mut(&mut self) -> &mut Self::Target {
 986        &mut self.guard
 987    }
 988}
 989
 990impl Drop for ConnectionPoolGuard<'_> {
 991    fn drop(&mut self) {
 992        #[cfg(test)]
 993        self.check_invariants();
 994    }
 995}
 996
 997fn broadcast<F>(
 998    sender_id: Option<ConnectionId>,
 999    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1000    mut f: F,
1001) where
1002    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1003{
1004    for receiver_id in receiver_ids {
1005        if Some(receiver_id) != sender_id {
1006            if let Err(error) = f(receiver_id) {
1007                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1008            }
1009        }
1010    }
1011}
1012
1013pub struct ProtocolVersion(u32);
1014
1015impl Header for ProtocolVersion {
1016    fn name() -> &'static HeaderName {
1017        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1018        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1019    }
1020
1021    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1022    where
1023        Self: Sized,
1024        I: Iterator<Item = &'i axum::http::HeaderValue>,
1025    {
1026        let version = values
1027            .next()
1028            .ok_or_else(axum::headers::Error::invalid)?
1029            .to_str()
1030            .map_err(|_| axum::headers::Error::invalid())?
1031            .parse()
1032            .map_err(|_| axum::headers::Error::invalid())?;
1033        Ok(Self(version))
1034    }
1035
1036    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1037        values.extend([self.0.to_string().parse().unwrap()]);
1038    }
1039}
1040
1041pub struct AppVersionHeader(SemanticVersion);
1042impl Header for AppVersionHeader {
1043    fn name() -> &'static HeaderName {
1044        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1045        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1046    }
1047
1048    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1049    where
1050        Self: Sized,
1051        I: Iterator<Item = &'i axum::http::HeaderValue>,
1052    {
1053        let version = values
1054            .next()
1055            .ok_or_else(axum::headers::Error::invalid)?
1056            .to_str()
1057            .map_err(|_| axum::headers::Error::invalid())?
1058            .parse()
1059            .map_err(|_| axum::headers::Error::invalid())?;
1060        Ok(Self(version))
1061    }
1062
1063    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1064        values.extend([self.0.to_string().parse().unwrap()]);
1065    }
1066}
1067
1068pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1069    Router::new()
1070        .route("/rpc", get(handle_websocket_request))
1071        .layer(
1072            ServiceBuilder::new()
1073                .layer(Extension(server.app_state.clone()))
1074                .layer(middleware::from_fn(auth::validate_header)),
1075        )
1076        .route("/metrics", get(handle_metrics))
1077        .layer(Extension(server))
1078}
1079
1080pub async fn handle_websocket_request(
1081    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1082    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1083    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1084    Extension(server): Extension<Arc<Server>>,
1085    Extension(principal): Extension<Principal>,
1086    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1087    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1088    ws: WebSocketUpgrade,
1089) -> axum::response::Response {
1090    if protocol_version != rpc::PROTOCOL_VERSION {
1091        return (
1092            StatusCode::UPGRADE_REQUIRED,
1093            "client must be upgraded".to_string(),
1094        )
1095            .into_response();
1096    }
1097
1098    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1099        return (
1100            StatusCode::UPGRADE_REQUIRED,
1101            "no version header found".to_string(),
1102        )
1103            .into_response();
1104    };
1105
1106    if !version.can_collaborate() {
1107        return (
1108            StatusCode::UPGRADE_REQUIRED,
1109            "client must be upgraded".to_string(),
1110        )
1111            .into_response();
1112    }
1113
1114    let socket_address = socket_address.to_string();
1115    ws.on_upgrade(move |socket| {
1116        let socket = socket
1117            .map_ok(to_tungstenite_message)
1118            .err_into()
1119            .with(|message| async move { to_axum_message(message) });
1120        let connection = Connection::new(Box::pin(socket));
1121        async move {
1122            server
1123                .handle_connection(
1124                    connection,
1125                    socket_address,
1126                    principal,
1127                    version,
1128                    country_code_header.map(|header| header.to_string()),
1129                    system_id_header.map(|header| header.to_string()),
1130                    None,
1131                    Executor::Production,
1132                )
1133                .await;
1134        }
1135    })
1136}
1137
1138pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1139    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1140    let connections_metric = CONNECTIONS_METRIC
1141        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1142
1143    let connections = server
1144        .connection_pool
1145        .lock()
1146        .connections()
1147        .filter(|connection| !connection.admin)
1148        .count();
1149    connections_metric.set(connections as _);
1150
1151    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1153        register_int_gauge!(
1154            "shared_projects",
1155            "number of open projects with one or more guests"
1156        )
1157        .unwrap()
1158    });
1159
1160    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1161    shared_projects_metric.set(shared_projects as _);
1162
1163    let encoder = prometheus::TextEncoder::new();
1164    let metric_families = prometheus::gather();
1165    let encoded_metrics = encoder
1166        .encode_to_string(&metric_families)
1167        .map_err(|err| anyhow!("{}", err))?;
1168    Ok(encoded_metrics)
1169}
1170
1171#[instrument(err, skip(executor))]
1172async fn connection_lost(
1173    session: Session,
1174    mut teardown: watch::Receiver<bool>,
1175    executor: Executor,
1176) -> Result<()> {
1177    session.peer.disconnect(session.connection_id);
1178    session
1179        .connection_pool()
1180        .await
1181        .remove_connection(session.connection_id)?;
1182
1183    session
1184        .db()
1185        .await
1186        .connection_lost(session.connection_id)
1187        .await
1188        .trace_err();
1189
1190    futures::select_biased! {
1191        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1192
1193            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1194            leave_room_for_session(&session, session.connection_id).await.trace_err();
1195            leave_channel_buffers_for_session(&session)
1196                .await
1197                .trace_err();
1198
1199            if !session
1200                .connection_pool()
1201                .await
1202                .is_user_online(session.user_id())
1203            {
1204                let db = session.db().await;
1205                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1206                    room_updated(&room, &session.peer);
1207                }
1208            }
1209
1210            update_user_contacts(session.user_id(), &session).await?;
1211        },
1212        _ = teardown.changed().fuse() => {}
1213    }
1214
1215    Ok(())
1216}
1217
1218/// Acknowledges a ping from a client, used to keep the connection alive.
1219async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1220    response.send(proto::Ack {})?;
1221    Ok(())
1222}
1223
1224/// Creates a new room for calling (outside of channels)
1225async fn create_room(
1226    _request: proto::CreateRoom,
1227    response: Response<proto::CreateRoom>,
1228    session: Session,
1229) -> Result<()> {
1230    let livekit_room = nanoid::nanoid!(30);
1231
1232    let live_kit_connection_info = util::maybe!(async {
1233        let live_kit = session.app_state.livekit_client.as_ref();
1234        let live_kit = live_kit?;
1235        let user_id = session.user_id().to_string();
1236
1237        let token = live_kit
1238            .room_token(&livekit_room, &user_id.to_string())
1239            .trace_err()?;
1240
1241        Some(proto::LiveKitConnectionInfo {
1242            server_url: live_kit.url().into(),
1243            token,
1244            can_publish: true,
1245        })
1246    })
1247    .await;
1248
1249    let room = session
1250        .db()
1251        .await
1252        .create_room(session.user_id(), session.connection_id, &livekit_room)
1253        .await?;
1254
1255    response.send(proto::CreateRoomResponse {
1256        room: Some(room.clone()),
1257        live_kit_connection_info,
1258    })?;
1259
1260    update_user_contacts(session.user_id(), &session).await?;
1261    Ok(())
1262}
1263
1264/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1265async fn join_room(
1266    request: proto::JoinRoom,
1267    response: Response<proto::JoinRoom>,
1268    session: Session,
1269) -> Result<()> {
1270    let room_id = RoomId::from_proto(request.id);
1271
1272    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1273
1274    if let Some(channel_id) = channel_id {
1275        return join_channel_internal(channel_id, Box::new(response), session).await;
1276    }
1277
1278    let joined_room = {
1279        let room = session
1280            .db()
1281            .await
1282            .join_room(room_id, session.user_id(), session.connection_id)
1283            .await?;
1284        room_updated(&room.room, &session.peer);
1285        room.into_inner()
1286    };
1287
1288    for connection_id in session
1289        .connection_pool()
1290        .await
1291        .user_connection_ids(session.user_id())
1292    {
1293        session
1294            .peer
1295            .send(
1296                connection_id,
1297                proto::CallCanceled {
1298                    room_id: room_id.to_proto(),
1299                },
1300            )
1301            .trace_err();
1302    }
1303
1304    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1305    {
1306        live_kit
1307            .room_token(
1308                &joined_room.room.livekit_room,
1309                &session.user_id().to_string(),
1310            )
1311            .trace_err()
1312            .map(|token| proto::LiveKitConnectionInfo {
1313                server_url: live_kit.url().into(),
1314                token,
1315                can_publish: true,
1316            })
1317    } else {
1318        None
1319    };
1320
1321    response.send(proto::JoinRoomResponse {
1322        room: Some(joined_room.room),
1323        channel_id: None,
1324        live_kit_connection_info,
1325    })?;
1326
1327    update_user_contacts(session.user_id(), &session).await?;
1328    Ok(())
1329}
1330
1331/// Rejoin room is used to reconnect to a room after connection errors.
1332async fn rejoin_room(
1333    request: proto::RejoinRoom,
1334    response: Response<proto::RejoinRoom>,
1335    session: Session,
1336) -> Result<()> {
1337    let room;
1338    let channel;
1339    {
1340        let mut rejoined_room = session
1341            .db()
1342            .await
1343            .rejoin_room(request, session.user_id(), session.connection_id)
1344            .await?;
1345
1346        response.send(proto::RejoinRoomResponse {
1347            room: Some(rejoined_room.room.clone()),
1348            reshared_projects: rejoined_room
1349                .reshared_projects
1350                .iter()
1351                .map(|project| proto::ResharedProject {
1352                    id: project.id.to_proto(),
1353                    collaborators: project
1354                        .collaborators
1355                        .iter()
1356                        .map(|collaborator| collaborator.to_proto())
1357                        .collect(),
1358                })
1359                .collect(),
1360            rejoined_projects: rejoined_room
1361                .rejoined_projects
1362                .iter()
1363                .map(|rejoined_project| rejoined_project.to_proto())
1364                .collect(),
1365        })?;
1366        room_updated(&rejoined_room.room, &session.peer);
1367
1368        for project in &rejoined_room.reshared_projects {
1369            for collaborator in &project.collaborators {
1370                session
1371                    .peer
1372                    .send(
1373                        collaborator.connection_id,
1374                        proto::UpdateProjectCollaborator {
1375                            project_id: project.id.to_proto(),
1376                            old_peer_id: Some(project.old_connection_id.into()),
1377                            new_peer_id: Some(session.connection_id.into()),
1378                        },
1379                    )
1380                    .trace_err();
1381            }
1382
1383            broadcast(
1384                Some(session.connection_id),
1385                project
1386                    .collaborators
1387                    .iter()
1388                    .map(|collaborator| collaborator.connection_id),
1389                |connection_id| {
1390                    session.peer.forward_send(
1391                        session.connection_id,
1392                        connection_id,
1393                        proto::UpdateProject {
1394                            project_id: project.id.to_proto(),
1395                            worktrees: project.worktrees.clone(),
1396                        },
1397                    )
1398                },
1399            );
1400        }
1401
1402        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1403
1404        let rejoined_room = rejoined_room.into_inner();
1405
1406        room = rejoined_room.room;
1407        channel = rejoined_room.channel;
1408    }
1409
1410    if let Some(channel) = channel {
1411        channel_updated(
1412            &channel,
1413            &room,
1414            &session.peer,
1415            &*session.connection_pool().await,
1416        );
1417    }
1418
1419    update_user_contacts(session.user_id(), &session).await?;
1420    Ok(())
1421}
1422
1423fn notify_rejoined_projects(
1424    rejoined_projects: &mut Vec<RejoinedProject>,
1425    session: &Session,
1426) -> Result<()> {
1427    for project in rejoined_projects.iter() {
1428        for collaborator in &project.collaborators {
1429            session
1430                .peer
1431                .send(
1432                    collaborator.connection_id,
1433                    proto::UpdateProjectCollaborator {
1434                        project_id: project.id.to_proto(),
1435                        old_peer_id: Some(project.old_connection_id.into()),
1436                        new_peer_id: Some(session.connection_id.into()),
1437                    },
1438                )
1439                .trace_err();
1440        }
1441    }
1442
1443    for project in rejoined_projects {
1444        for worktree in mem::take(&mut project.worktrees) {
1445            // Stream this worktree's entries.
1446            let message = proto::UpdateWorktree {
1447                project_id: project.id.to_proto(),
1448                worktree_id: worktree.id,
1449                abs_path: worktree.abs_path.clone(),
1450                root_name: worktree.root_name,
1451                updated_entries: worktree.updated_entries,
1452                removed_entries: worktree.removed_entries,
1453                scan_id: worktree.scan_id,
1454                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1455                updated_repositories: worktree.updated_repositories,
1456                removed_repositories: worktree.removed_repositories,
1457            };
1458            for update in proto::split_worktree_update(message) {
1459                session.peer.send(session.connection_id, update)?;
1460            }
1461
1462            // Stream this worktree's diagnostics.
1463            for summary in worktree.diagnostic_summaries {
1464                session.peer.send(
1465                    session.connection_id,
1466                    proto::UpdateDiagnosticSummary {
1467                        project_id: project.id.to_proto(),
1468                        worktree_id: worktree.id,
1469                        summary: Some(summary),
1470                    },
1471                )?;
1472            }
1473
1474            for settings_file in worktree.settings_files {
1475                session.peer.send(
1476                    session.connection_id,
1477                    proto::UpdateWorktreeSettings {
1478                        project_id: project.id.to_proto(),
1479                        worktree_id: worktree.id,
1480                        path: settings_file.path,
1481                        content: Some(settings_file.content),
1482                        kind: Some(settings_file.kind.to_proto().into()),
1483                    },
1484                )?;
1485            }
1486        }
1487
1488        for repository in mem::take(&mut project.updated_repositories) {
1489            for update in split_repository_update(repository) {
1490                session.peer.send(session.connection_id, update)?;
1491            }
1492        }
1493
1494        for id in mem::take(&mut project.removed_repositories) {
1495            session.peer.send(
1496                session.connection_id,
1497                proto::RemoveRepository {
1498                    project_id: project.id.to_proto(),
1499                    id,
1500                },
1501            )?;
1502        }
1503    }
1504
1505    Ok(())
1506}
1507
1508/// leave room disconnects from the room.
1509async fn leave_room(
1510    _: proto::LeaveRoom,
1511    response: Response<proto::LeaveRoom>,
1512    session: Session,
1513) -> Result<()> {
1514    leave_room_for_session(&session, session.connection_id).await?;
1515    response.send(proto::Ack {})?;
1516    Ok(())
1517}
1518
1519/// Updates the permissions of someone else in the room.
1520async fn set_room_participant_role(
1521    request: proto::SetRoomParticipantRole,
1522    response: Response<proto::SetRoomParticipantRole>,
1523    session: Session,
1524) -> Result<()> {
1525    let user_id = UserId::from_proto(request.user_id);
1526    let role = ChannelRole::from(request.role());
1527
1528    let (livekit_room, can_publish) = {
1529        let room = session
1530            .db()
1531            .await
1532            .set_room_participant_role(
1533                session.user_id(),
1534                RoomId::from_proto(request.room_id),
1535                user_id,
1536                role,
1537            )
1538            .await?;
1539
1540        let livekit_room = room.livekit_room.clone();
1541        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1542        room_updated(&room, &session.peer);
1543        (livekit_room, can_publish)
1544    };
1545
1546    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1547        live_kit
1548            .update_participant(
1549                livekit_room.clone(),
1550                request.user_id.to_string(),
1551                livekit_api::proto::ParticipantPermission {
1552                    can_subscribe: true,
1553                    can_publish,
1554                    can_publish_data: can_publish,
1555                    hidden: false,
1556                    recorder: false,
1557                },
1558            )
1559            .await
1560            .trace_err();
1561    }
1562
1563    response.send(proto::Ack {})?;
1564    Ok(())
1565}
1566
1567/// Call someone else into the current room
1568async fn call(
1569    request: proto::Call,
1570    response: Response<proto::Call>,
1571    session: Session,
1572) -> Result<()> {
1573    let room_id = RoomId::from_proto(request.room_id);
1574    let calling_user_id = session.user_id();
1575    let calling_connection_id = session.connection_id;
1576    let called_user_id = UserId::from_proto(request.called_user_id);
1577    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1578    if !session
1579        .db()
1580        .await
1581        .has_contact(calling_user_id, called_user_id)
1582        .await?
1583    {
1584        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1585    }
1586
1587    let incoming_call = {
1588        let (room, incoming_call) = &mut *session
1589            .db()
1590            .await
1591            .call(
1592                room_id,
1593                calling_user_id,
1594                calling_connection_id,
1595                called_user_id,
1596                initial_project_id,
1597            )
1598            .await?;
1599        room_updated(room, &session.peer);
1600        mem::take(incoming_call)
1601    };
1602    update_user_contacts(called_user_id, &session).await?;
1603
1604    let mut calls = session
1605        .connection_pool()
1606        .await
1607        .user_connection_ids(called_user_id)
1608        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1609        .collect::<FuturesUnordered<_>>();
1610
1611    while let Some(call_response) = calls.next().await {
1612        match call_response.as_ref() {
1613            Ok(_) => {
1614                response.send(proto::Ack {})?;
1615                return Ok(());
1616            }
1617            Err(_) => {
1618                call_response.trace_err();
1619            }
1620        }
1621    }
1622
1623    {
1624        let room = session
1625            .db()
1626            .await
1627            .call_failed(room_id, called_user_id)
1628            .await?;
1629        room_updated(&room, &session.peer);
1630    }
1631    update_user_contacts(called_user_id, &session).await?;
1632
1633    Err(anyhow!("failed to ring user"))?
1634}
1635
1636/// Cancel an outgoing call.
1637async fn cancel_call(
1638    request: proto::CancelCall,
1639    response: Response<proto::CancelCall>,
1640    session: Session,
1641) -> Result<()> {
1642    let called_user_id = UserId::from_proto(request.called_user_id);
1643    let room_id = RoomId::from_proto(request.room_id);
1644    {
1645        let room = session
1646            .db()
1647            .await
1648            .cancel_call(room_id, session.connection_id, called_user_id)
1649            .await?;
1650        room_updated(&room, &session.peer);
1651    }
1652
1653    for connection_id in session
1654        .connection_pool()
1655        .await
1656        .user_connection_ids(called_user_id)
1657    {
1658        session
1659            .peer
1660            .send(
1661                connection_id,
1662                proto::CallCanceled {
1663                    room_id: room_id.to_proto(),
1664                },
1665            )
1666            .trace_err();
1667    }
1668    response.send(proto::Ack {})?;
1669
1670    update_user_contacts(called_user_id, &session).await?;
1671    Ok(())
1672}
1673
1674/// Decline an incoming call.
1675async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1676    let room_id = RoomId::from_proto(message.room_id);
1677    {
1678        let room = session
1679            .db()
1680            .await
1681            .decline_call(Some(room_id), session.user_id())
1682            .await?
1683            .ok_or_else(|| anyhow!("failed to decline call"))?;
1684        room_updated(&room, &session.peer);
1685    }
1686
1687    for connection_id in session
1688        .connection_pool()
1689        .await
1690        .user_connection_ids(session.user_id())
1691    {
1692        session
1693            .peer
1694            .send(
1695                connection_id,
1696                proto::CallCanceled {
1697                    room_id: room_id.to_proto(),
1698                },
1699            )
1700            .trace_err();
1701    }
1702    update_user_contacts(session.user_id(), &session).await?;
1703    Ok(())
1704}
1705
1706/// Updates other participants in the room with your current location.
1707async fn update_participant_location(
1708    request: proto::UpdateParticipantLocation,
1709    response: Response<proto::UpdateParticipantLocation>,
1710    session: Session,
1711) -> Result<()> {
1712    let room_id = RoomId::from_proto(request.room_id);
1713    let location = request
1714        .location
1715        .ok_or_else(|| anyhow!("invalid location"))?;
1716
1717    let db = session.db().await;
1718    let room = db
1719        .update_room_participant_location(room_id, session.connection_id, location)
1720        .await?;
1721
1722    room_updated(&room, &session.peer);
1723    response.send(proto::Ack {})?;
1724    Ok(())
1725}
1726
1727/// Share a project into the room.
1728async fn share_project(
1729    request: proto::ShareProject,
1730    response: Response<proto::ShareProject>,
1731    session: Session,
1732) -> Result<()> {
1733    let (project_id, room) = &*session
1734        .db()
1735        .await
1736        .share_project(
1737            RoomId::from_proto(request.room_id),
1738            session.connection_id,
1739            &request.worktrees,
1740            request.is_ssh_project,
1741        )
1742        .await?;
1743    response.send(proto::ShareProjectResponse {
1744        project_id: project_id.to_proto(),
1745    })?;
1746    room_updated(room, &session.peer);
1747
1748    Ok(())
1749}
1750
1751/// Unshare a project from the room.
1752async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1753    let project_id = ProjectId::from_proto(message.project_id);
1754    unshare_project_internal(project_id, session.connection_id, &session).await
1755}
1756
1757async fn unshare_project_internal(
1758    project_id: ProjectId,
1759    connection_id: ConnectionId,
1760    session: &Session,
1761) -> Result<()> {
1762    let delete = {
1763        let room_guard = session
1764            .db()
1765            .await
1766            .unshare_project(project_id, connection_id)
1767            .await?;
1768
1769        let (delete, room, guest_connection_ids) = &*room_guard;
1770
1771        let message = proto::UnshareProject {
1772            project_id: project_id.to_proto(),
1773        };
1774
1775        broadcast(
1776            Some(connection_id),
1777            guest_connection_ids.iter().copied(),
1778            |conn_id| session.peer.send(conn_id, message.clone()),
1779        );
1780        if let Some(room) = room {
1781            room_updated(room, &session.peer);
1782        }
1783
1784        *delete
1785    };
1786
1787    if delete {
1788        let db = session.db().await;
1789        db.delete_project(project_id).await?;
1790    }
1791
1792    Ok(())
1793}
1794
1795/// Join someone elses shared project.
1796async fn join_project(
1797    request: proto::JoinProject,
1798    response: Response<proto::JoinProject>,
1799    session: Session,
1800) -> Result<()> {
1801    let project_id = ProjectId::from_proto(request.project_id);
1802
1803    tracing::info!(%project_id, "join project");
1804
1805    let db = session.db().await;
1806    let (project, replica_id) = &mut *db
1807        .join_project(project_id, session.connection_id, session.user_id())
1808        .await?;
1809    drop(db);
1810    tracing::info!(%project_id, "join remote project");
1811    join_project_internal(response, session, project, replica_id)
1812}
1813
/// Abstraction over the value used to answer a join-project request.
///
/// `join_project_internal` accepts any implementor, so it is not tied to one
/// concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result` as the answer.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Lets the regular response handle for `proto::JoinProject` be used as a
// `JoinProjectInternalResponse`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to `Response`'s own `send` (defined outside this
        // section); presumably spelled out so it cannot be confused with this
        // trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1822
1823fn join_project_internal(
1824    response: impl JoinProjectInternalResponse,
1825    session: Session,
1826    project: &mut Project,
1827    replica_id: &ReplicaId,
1828) -> Result<()> {
1829    let collaborators = project
1830        .collaborators
1831        .iter()
1832        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1833        .map(|collaborator| collaborator.to_proto())
1834        .collect::<Vec<_>>();
1835    let project_id = project.id;
1836    let guest_user_id = session.user_id();
1837
1838    let worktrees = project
1839        .worktrees
1840        .iter()
1841        .map(|(id, worktree)| proto::WorktreeMetadata {
1842            id: *id,
1843            root_name: worktree.root_name.clone(),
1844            visible: worktree.visible,
1845            abs_path: worktree.abs_path.clone(),
1846        })
1847        .collect::<Vec<_>>();
1848
1849    let add_project_collaborator = proto::AddProjectCollaborator {
1850        project_id: project_id.to_proto(),
1851        collaborator: Some(proto::Collaborator {
1852            peer_id: Some(session.connection_id.into()),
1853            replica_id: replica_id.0 as u32,
1854            user_id: guest_user_id.to_proto(),
1855            is_host: false,
1856        }),
1857    };
1858
1859    for collaborator in &collaborators {
1860        session
1861            .peer
1862            .send(
1863                collaborator.peer_id.unwrap().into(),
1864                add_project_collaborator.clone(),
1865            )
1866            .trace_err();
1867    }
1868
1869    // First, we send the metadata associated with each worktree.
1870    response.send(proto::JoinProjectResponse {
1871        project_id: project.id.0 as u64,
1872        worktrees: worktrees.clone(),
1873        replica_id: replica_id.0 as u32,
1874        collaborators: collaborators.clone(),
1875        language_servers: project.language_servers.clone(),
1876        role: project.role.into(),
1877    })?;
1878
1879    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1880        // Stream this worktree's entries.
1881        let message = proto::UpdateWorktree {
1882            project_id: project_id.to_proto(),
1883            worktree_id,
1884            abs_path: worktree.abs_path.clone(),
1885            root_name: worktree.root_name,
1886            updated_entries: worktree.entries,
1887            removed_entries: Default::default(),
1888            scan_id: worktree.scan_id,
1889            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1890            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1891            removed_repositories: Default::default(),
1892        };
1893        for update in proto::split_worktree_update(message) {
1894            session.peer.send(session.connection_id, update.clone())?;
1895        }
1896
1897        // Stream this worktree's diagnostics.
1898        for summary in worktree.diagnostic_summaries {
1899            session.peer.send(
1900                session.connection_id,
1901                proto::UpdateDiagnosticSummary {
1902                    project_id: project_id.to_proto(),
1903                    worktree_id: worktree.id,
1904                    summary: Some(summary),
1905                },
1906            )?;
1907        }
1908
1909        for settings_file in worktree.settings_files {
1910            session.peer.send(
1911                session.connection_id,
1912                proto::UpdateWorktreeSettings {
1913                    project_id: project_id.to_proto(),
1914                    worktree_id: worktree.id,
1915                    path: settings_file.path,
1916                    content: Some(settings_file.content),
1917                    kind: Some(settings_file.kind.to_proto() as i32),
1918                },
1919            )?;
1920        }
1921    }
1922
1923    for repository in mem::take(&mut project.repositories) {
1924        for update in split_repository_update(repository) {
1925            session.peer.send(session.connection_id, update)?;
1926        }
1927    }
1928
1929    for language_server in &project.language_servers {
1930        session.peer.send(
1931            session.connection_id,
1932            proto::UpdateLanguageServer {
1933                project_id: project_id.to_proto(),
1934                language_server_id: language_server.id,
1935                variant: Some(
1936                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1937                        proto::LspDiskBasedDiagnosticsUpdated {},
1938                    ),
1939                ),
1940            },
1941        )?;
1942    }
1943
1944    Ok(())
1945}
1946
1947/// Leave someone elses shared project.
1948async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1949    let sender_id = session.connection_id;
1950    let project_id = ProjectId::from_proto(request.project_id);
1951    let db = session.db().await;
1952
1953    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1954    tracing::info!(
1955        %project_id,
1956        "leave project"
1957    );
1958
1959    project_left(project, &session);
1960    if let Some(room) = room {
1961        room_updated(room, &session.peer);
1962    }
1963
1964    Ok(())
1965}
1966
1967/// Updates other participants with changes to the project
1968async fn update_project(
1969    request: proto::UpdateProject,
1970    response: Response<proto::UpdateProject>,
1971    session: Session,
1972) -> Result<()> {
1973    let project_id = ProjectId::from_proto(request.project_id);
1974    let (room, guest_connection_ids) = &*session
1975        .db()
1976        .await
1977        .update_project(project_id, session.connection_id, &request.worktrees)
1978        .await?;
1979    broadcast(
1980        Some(session.connection_id),
1981        guest_connection_ids.iter().copied(),
1982        |connection_id| {
1983            session
1984                .peer
1985                .forward_send(session.connection_id, connection_id, request.clone())
1986        },
1987    );
1988    if let Some(room) = room {
1989        room_updated(room, &session.peer);
1990    }
1991    response.send(proto::Ack {})?;
1992
1993    Ok(())
1994}
1995
1996/// Updates other participants with changes to the worktree
1997async fn update_worktree(
1998    request: proto::UpdateWorktree,
1999    response: Response<proto::UpdateWorktree>,
2000    session: Session,
2001) -> Result<()> {
2002    let guest_connection_ids = session
2003        .db()
2004        .await
2005        .update_worktree(&request, session.connection_id)
2006        .await?;
2007
2008    broadcast(
2009        Some(session.connection_id),
2010        guest_connection_ids.iter().copied(),
2011        |connection_id| {
2012            session
2013                .peer
2014                .forward_send(session.connection_id, connection_id, request.clone())
2015        },
2016    );
2017    response.send(proto::Ack {})?;
2018    Ok(())
2019}
2020
2021async fn update_repository(
2022    request: proto::UpdateRepository,
2023    response: Response<proto::UpdateRepository>,
2024    session: Session,
2025) -> Result<()> {
2026    let guest_connection_ids = session
2027        .db()
2028        .await
2029        .update_repository(&request, session.connection_id)
2030        .await?;
2031
2032    broadcast(
2033        Some(session.connection_id),
2034        guest_connection_ids.iter().copied(),
2035        |connection_id| {
2036            session
2037                .peer
2038                .forward_send(session.connection_id, connection_id, request.clone())
2039        },
2040    );
2041    response.send(proto::Ack {})?;
2042    Ok(())
2043}
2044
2045async fn remove_repository(
2046    request: proto::RemoveRepository,
2047    response: Response<proto::RemoveRepository>,
2048    session: Session,
2049) -> Result<()> {
2050    let guest_connection_ids = session
2051        .db()
2052        .await
2053        .remove_repository(&request, session.connection_id)
2054        .await?;
2055
2056    broadcast(
2057        Some(session.connection_id),
2058        guest_connection_ids.iter().copied(),
2059        |connection_id| {
2060            session
2061                .peer
2062                .forward_send(session.connection_id, connection_id, request.clone())
2063        },
2064    );
2065    response.send(proto::Ack {})?;
2066    Ok(())
2067}
2068
2069/// Updates other participants with changes to the diagnostics
2070async fn update_diagnostic_summary(
2071    message: proto::UpdateDiagnosticSummary,
2072    session: Session,
2073) -> Result<()> {
2074    let guest_connection_ids = session
2075        .db()
2076        .await
2077        .update_diagnostic_summary(&message, session.connection_id)
2078        .await?;
2079
2080    broadcast(
2081        Some(session.connection_id),
2082        guest_connection_ids.iter().copied(),
2083        |connection_id| {
2084            session
2085                .peer
2086                .forward_send(session.connection_id, connection_id, message.clone())
2087        },
2088    );
2089
2090    Ok(())
2091}
2092
2093/// Updates other participants with changes to the worktree settings
2094async fn update_worktree_settings(
2095    message: proto::UpdateWorktreeSettings,
2096    session: Session,
2097) -> Result<()> {
2098    let guest_connection_ids = session
2099        .db()
2100        .await
2101        .update_worktree_settings(&message, session.connection_id)
2102        .await?;
2103
2104    broadcast(
2105        Some(session.connection_id),
2106        guest_connection_ids.iter().copied(),
2107        |connection_id| {
2108            session
2109                .peer
2110                .forward_send(session.connection_id, connection_id, message.clone())
2111        },
2112    );
2113
2114    Ok(())
2115}
2116
2117/// Notify other participants that a language server has started.
2118async fn start_language_server(
2119    request: proto::StartLanguageServer,
2120    session: Session,
2121) -> Result<()> {
2122    let guest_connection_ids = session
2123        .db()
2124        .await
2125        .start_language_server(&request, session.connection_id)
2126        .await?;
2127
2128    broadcast(
2129        Some(session.connection_id),
2130        guest_connection_ids.iter().copied(),
2131        |connection_id| {
2132            session
2133                .peer
2134                .forward_send(session.connection_id, connection_id, request.clone())
2135        },
2136    );
2137    Ok(())
2138}
2139
2140/// Notify other participants that a language server has changed.
2141async fn update_language_server(
2142    request: proto::UpdateLanguageServer,
2143    session: Session,
2144) -> Result<()> {
2145    let project_id = ProjectId::from_proto(request.project_id);
2146    let project_connection_ids = session
2147        .db()
2148        .await
2149        .project_connection_ids(project_id, session.connection_id, true)
2150        .await?;
2151    broadcast(
2152        Some(session.connection_id),
2153        project_connection_ids.iter().copied(),
2154        |connection_id| {
2155            session
2156                .peer
2157                .forward_send(session.connection_id, connection_id, request.clone())
2158        },
2159    );
2160    Ok(())
2161}
2162
2163/// forward a project request to the host. These requests should be read only
2164/// as guests are allowed to send them.
2165async fn forward_read_only_project_request<T>(
2166    request: T,
2167    response: Response<T>,
2168    session: Session,
2169) -> Result<()>
2170where
2171    T: EntityMessage + RequestMessage,
2172{
2173    let project_id = ProjectId::from_proto(request.remote_entity_id());
2174    let host_connection_id = session
2175        .db()
2176        .await
2177        .host_for_read_only_project_request(project_id, session.connection_id)
2178        .await?;
2179    let payload = session
2180        .peer
2181        .forward_request(session.connection_id, host_connection_id, request)
2182        .await?;
2183    response.send(payload)?;
2184    Ok(())
2185}
2186
2187async fn forward_find_search_candidates_request(
2188    request: proto::FindSearchCandidates,
2189    response: Response<proto::FindSearchCandidates>,
2190    session: Session,
2191) -> Result<()> {
2192    let project_id = ProjectId::from_proto(request.remote_entity_id());
2193    let host_connection_id = session
2194        .db()
2195        .await
2196        .host_for_read_only_project_request(project_id, session.connection_id)
2197        .await?;
2198    let payload = session
2199        .peer
2200        .forward_request(session.connection_id, host_connection_id, request)
2201        .await?;
2202    response.send(payload)?;
2203    Ok(())
2204}
2205
2206/// forward a project request to the host. These requests are disallowed
2207/// for guests.
2208async fn forward_mutating_project_request<T>(
2209    request: T,
2210    response: Response<T>,
2211    session: Session,
2212) -> Result<()>
2213where
2214    T: EntityMessage + RequestMessage,
2215{
2216    let project_id = ProjectId::from_proto(request.remote_entity_id());
2217
2218    let host_connection_id = session
2219        .db()
2220        .await
2221        .host_for_mutating_project_request(project_id, session.connection_id)
2222        .await?;
2223    let payload = session
2224        .peer
2225        .forward_request(session.connection_id, host_connection_id, request)
2226        .await?;
2227    response.send(payload)?;
2228    Ok(())
2229}
2230
2231/// Notify other participants that a new buffer has been created
2232async fn create_buffer_for_peer(
2233    request: proto::CreateBufferForPeer,
2234    session: Session,
2235) -> Result<()> {
2236    session
2237        .db()
2238        .await
2239        .check_user_is_project_host(
2240            ProjectId::from_proto(request.project_id),
2241            session.connection_id,
2242        )
2243        .await?;
2244    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2245    session
2246        .peer
2247        .forward_send(session.connection_id, peer_id.into(), request)?;
2248    Ok(())
2249}
2250
2251/// Notify other participants that a buffer has been updated. This is
2252/// allowed for guests as long as the update is limited to selections.
2253async fn update_buffer(
2254    request: proto::UpdateBuffer,
2255    response: Response<proto::UpdateBuffer>,
2256    session: Session,
2257) -> Result<()> {
2258    let project_id = ProjectId::from_proto(request.project_id);
2259    let mut capability = Capability::ReadOnly;
2260
2261    for op in request.operations.iter() {
2262        match op.variant {
2263            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2264            Some(_) => capability = Capability::ReadWrite,
2265        }
2266    }
2267
2268    let host = {
2269        let guard = session
2270            .db()
2271            .await
2272            .connections_for_buffer_update(project_id, session.connection_id, capability)
2273            .await?;
2274
2275        let (host, guests) = &*guard;
2276
2277        broadcast(
2278            Some(session.connection_id),
2279            guests.clone(),
2280            |connection_id| {
2281                session
2282                    .peer
2283                    .forward_send(session.connection_id, connection_id, request.clone())
2284            },
2285        );
2286
2287        *host
2288    };
2289
2290    if host != session.connection_id {
2291        session
2292            .peer
2293            .forward_request(session.connection_id, host, request.clone())
2294            .await?;
2295    }
2296
2297    response.send(proto::Ack {})?;
2298    Ok(())
2299}
2300
2301async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2302    let project_id = ProjectId::from_proto(message.project_id);
2303
2304    let operation = message.operation.as_ref().context("invalid operation")?;
2305    let capability = match operation.variant.as_ref() {
2306        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2307            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2308                match buffer_op.variant {
2309                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2310                        Capability::ReadOnly
2311                    }
2312                    _ => Capability::ReadWrite,
2313                }
2314            } else {
2315                Capability::ReadWrite
2316            }
2317        }
2318        Some(_) => Capability::ReadWrite,
2319        None => Capability::ReadOnly,
2320    };
2321
2322    let guard = session
2323        .db()
2324        .await
2325        .connections_for_buffer_update(project_id, session.connection_id, capability)
2326        .await?;
2327
2328    let (host, guests) = &*guard;
2329
2330    broadcast(
2331        Some(session.connection_id),
2332        guests.iter().chain([host]).copied(),
2333        |connection_id| {
2334            session
2335                .peer
2336                .forward_send(session.connection_id, connection_id, message.clone())
2337        },
2338    );
2339
2340    Ok(())
2341}
2342
2343/// Notify other participants that a project has been updated.
2344async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2345    request: T,
2346    session: Session,
2347) -> Result<()> {
2348    let project_id = ProjectId::from_proto(request.remote_entity_id());
2349    let project_connection_ids = session
2350        .db()
2351        .await
2352        .project_connection_ids(project_id, session.connection_id, false)
2353        .await?;
2354
2355    broadcast(
2356        Some(session.connection_id),
2357        project_connection_ids.iter().copied(),
2358        |connection_id| {
2359            session
2360                .peer
2361                .forward_send(session.connection_id, connection_id, request.clone())
2362        },
2363    );
2364    Ok(())
2365}
2366
2367/// Start following another user in a call.
2368async fn follow(
2369    request: proto::Follow,
2370    response: Response<proto::Follow>,
2371    session: Session,
2372) -> Result<()> {
2373    let room_id = RoomId::from_proto(request.room_id);
2374    let project_id = request.project_id.map(ProjectId::from_proto);
2375    let leader_id = request
2376        .leader_id
2377        .ok_or_else(|| anyhow!("invalid leader id"))?
2378        .into();
2379    let follower_id = session.connection_id;
2380
2381    session
2382        .db()
2383        .await
2384        .check_room_participants(room_id, leader_id, session.connection_id)
2385        .await?;
2386
2387    let response_payload = session
2388        .peer
2389        .forward_request(session.connection_id, leader_id, request)
2390        .await?;
2391    response.send(response_payload)?;
2392
2393    if let Some(project_id) = project_id {
2394        let room = session
2395            .db()
2396            .await
2397            .follow(room_id, project_id, leader_id, follower_id)
2398            .await?;
2399        room_updated(&room, &session.peer);
2400    }
2401
2402    Ok(())
2403}
2404
2405/// Stop following another user in a call.
2406async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2407    let room_id = RoomId::from_proto(request.room_id);
2408    let project_id = request.project_id.map(ProjectId::from_proto);
2409    let leader_id = request
2410        .leader_id
2411        .ok_or_else(|| anyhow!("invalid leader id"))?
2412        .into();
2413    let follower_id = session.connection_id;
2414
2415    session
2416        .db()
2417        .await
2418        .check_room_participants(room_id, leader_id, session.connection_id)
2419        .await?;
2420
2421    session
2422        .peer
2423        .forward_send(session.connection_id, leader_id, request)?;
2424
2425    if let Some(project_id) = project_id {
2426        let room = session
2427            .db()
2428            .await
2429            .unfollow(room_id, project_id, leader_id, follower_id)
2430            .await?;
2431        room_updated(&room, &session.peer);
2432    }
2433
2434    Ok(())
2435}
2436
2437/// Notify everyone following you of your current location.
2438async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2439    let room_id = RoomId::from_proto(request.room_id);
2440    let database = session.db.lock().await;
2441
2442    let connection_ids = if let Some(project_id) = request.project_id {
2443        let project_id = ProjectId::from_proto(project_id);
2444        database
2445            .project_connection_ids(project_id, session.connection_id, true)
2446            .await?
2447    } else {
2448        database
2449            .room_connection_ids(room_id, session.connection_id)
2450            .await?
2451    };
2452
2453    // For now, don't send view update messages back to that view's current leader.
2454    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2455        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2456        _ => None,
2457    });
2458
2459    for connection_id in connection_ids.iter().cloned() {
2460        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2461            session
2462                .peer
2463                .forward_send(session.connection_id, connection_id, request.clone())?;
2464        }
2465    }
2466    Ok(())
2467}
2468
2469/// Get public data about users.
2470async fn get_users(
2471    request: proto::GetUsers,
2472    response: Response<proto::GetUsers>,
2473    session: Session,
2474) -> Result<()> {
2475    let user_ids = request
2476        .user_ids
2477        .into_iter()
2478        .map(UserId::from_proto)
2479        .collect();
2480    let users = session
2481        .db()
2482        .await
2483        .get_users_by_ids(user_ids)
2484        .await?
2485        .into_iter()
2486        .map(|user| proto::User {
2487            id: user.id.to_proto(),
2488            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2489            github_login: user.github_login,
2490            email: user.email_address,
2491            name: user.name,
2492        })
2493        .collect();
2494    response.send(proto::UsersResponse { users })?;
2495    Ok(())
2496}
2497
2498/// Search for users (to invite) buy Github login
2499async fn fuzzy_search_users(
2500    request: proto::FuzzySearchUsers,
2501    response: Response<proto::FuzzySearchUsers>,
2502    session: Session,
2503) -> Result<()> {
2504    let query = request.query;
2505    let users = match query.len() {
2506        0 => vec![],
2507        1 | 2 => session
2508            .db()
2509            .await
2510            .get_user_by_github_login(&query)
2511            .await?
2512            .into_iter()
2513            .collect(),
2514        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2515    };
2516    let users = users
2517        .into_iter()
2518        .filter(|user| user.id != session.user_id())
2519        .map(|user| proto::User {
2520            id: user.id.to_proto(),
2521            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2522            github_login: user.github_login,
2523            name: user.name,
2524            email: user.email_address,
2525        })
2526        .collect();
2527    response.send(proto::UsersResponse { users })?;
2528    Ok(())
2529}
2530
2531/// Send a contact request to another user.
2532async fn request_contact(
2533    request: proto::RequestContact,
2534    response: Response<proto::RequestContact>,
2535    session: Session,
2536) -> Result<()> {
2537    let requester_id = session.user_id();
2538    let responder_id = UserId::from_proto(request.responder_id);
2539    if requester_id == responder_id {
2540        return Err(anyhow!("cannot add yourself as a contact"))?;
2541    }
2542
2543    let notifications = session
2544        .db()
2545        .await
2546        .send_contact_request(requester_id, responder_id)
2547        .await?;
2548
2549    // Update outgoing contact requests of requester
2550    let mut update = proto::UpdateContacts::default();
2551    update.outgoing_requests.push(responder_id.to_proto());
2552    for connection_id in session
2553        .connection_pool()
2554        .await
2555        .user_connection_ids(requester_id)
2556    {
2557        session.peer.send(connection_id, update.clone())?;
2558    }
2559
2560    // Update incoming contact requests of responder
2561    let mut update = proto::UpdateContacts::default();
2562    update
2563        .incoming_requests
2564        .push(proto::IncomingContactRequest {
2565            requester_id: requester_id.to_proto(),
2566        });
2567    let connection_pool = session.connection_pool().await;
2568    for connection_id in connection_pool.user_connection_ids(responder_id) {
2569        session.peer.send(connection_id, update.clone())?;
2570    }
2571
2572    send_notifications(&connection_pool, &session.peer, notifications);
2573
2574    response.send(proto::Ack {})?;
2575    Ok(())
2576}
2577
2578/// Accept or decline a contact request
2579async fn respond_to_contact_request(
2580    request: proto::RespondToContactRequest,
2581    response: Response<proto::RespondToContactRequest>,
2582    session: Session,
2583) -> Result<()> {
2584    let responder_id = session.user_id();
2585    let requester_id = UserId::from_proto(request.requester_id);
2586    let db = session.db().await;
2587    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2588        db.dismiss_contact_notification(responder_id, requester_id)
2589            .await?;
2590    } else {
2591        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2592
2593        let notifications = db
2594            .respond_to_contact_request(responder_id, requester_id, accept)
2595            .await?;
2596        let requester_busy = db.is_user_busy(requester_id).await?;
2597        let responder_busy = db.is_user_busy(responder_id).await?;
2598
2599        let pool = session.connection_pool().await;
2600        // Update responder with new contact
2601        let mut update = proto::UpdateContacts::default();
2602        if accept {
2603            update
2604                .contacts
2605                .push(contact_for_user(requester_id, requester_busy, &pool));
2606        }
2607        update
2608            .remove_incoming_requests
2609            .push(requester_id.to_proto());
2610        for connection_id in pool.user_connection_ids(responder_id) {
2611            session.peer.send(connection_id, update.clone())?;
2612        }
2613
2614        // Update requester with new contact
2615        let mut update = proto::UpdateContacts::default();
2616        if accept {
2617            update
2618                .contacts
2619                .push(contact_for_user(responder_id, responder_busy, &pool));
2620        }
2621        update
2622            .remove_outgoing_requests
2623            .push(responder_id.to_proto());
2624
2625        for connection_id in pool.user_connection_ids(requester_id) {
2626            session.peer.send(connection_id, update.clone())?;
2627        }
2628
2629        send_notifications(&pool, &session.peer, notifications);
2630    }
2631
2632    response.send(proto::Ack {})?;
2633    Ok(())
2634}
2635
2636/// Remove a contact.
2637async fn remove_contact(
2638    request: proto::RemoveContact,
2639    response: Response<proto::RemoveContact>,
2640    session: Session,
2641) -> Result<()> {
2642    let requester_id = session.user_id();
2643    let responder_id = UserId::from_proto(request.user_id);
2644    let db = session.db().await;
2645    let (contact_accepted, deleted_notification_id) =
2646        db.remove_contact(requester_id, responder_id).await?;
2647
2648    let pool = session.connection_pool().await;
2649    // Update outgoing contact requests of requester
2650    let mut update = proto::UpdateContacts::default();
2651    if contact_accepted {
2652        update.remove_contacts.push(responder_id.to_proto());
2653    } else {
2654        update
2655            .remove_outgoing_requests
2656            .push(responder_id.to_proto());
2657    }
2658    for connection_id in pool.user_connection_ids(requester_id) {
2659        session.peer.send(connection_id, update.clone())?;
2660    }
2661
2662    // Update incoming contact requests of responder
2663    let mut update = proto::UpdateContacts::default();
2664    if contact_accepted {
2665        update.remove_contacts.push(requester_id.to_proto());
2666    } else {
2667        update
2668            .remove_incoming_requests
2669            .push(requester_id.to_proto());
2670    }
2671    for connection_id in pool.user_connection_ids(responder_id) {
2672        session.peer.send(connection_id, update.clone())?;
2673        if let Some(notification_id) = deleted_notification_id {
2674            session.peer.send(
2675                connection_id,
2676                proto::DeleteNotification {
2677                    notification_id: notification_id.to_proto(),
2678                },
2679            )?;
2680        }
2681    }
2682
2683    response.send(proto::Ack {})?;
2684    Ok(())
2685}
2686
2687fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2688    version.0.minor() < 139
2689}
2690
2691async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2692    let db = session.db().await;
2693
2694    let feature_flags = db.get_user_flags(user_id).await?;
2695    let plan = session.current_plan(&db).await?;
2696    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2697    let billing_preferences = db.get_billing_preferences(user_id).await?;
2698
2699    let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
2700        let subscription = db.get_active_billing_subscription(user_id).await?;
2701
2702        let subscription_period = crate::db::billing_subscription::Model::current_period(
2703            subscription,
2704            session.is_staff(),
2705        );
2706
2707        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2708            llm_db
2709                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2710                .await?
2711        } else {
2712            None
2713        };
2714
2715        (subscription_period, usage)
2716    } else {
2717        (None, None)
2718    };
2719
2720    session
2721        .peer
2722        .send(
2723            session.connection_id,
2724            proto::UpdateUserPlan {
2725                plan: plan.into(),
2726                trial_started_at: billing_customer
2727                    .and_then(|billing_customer| billing_customer.trial_started_at)
2728                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2729                is_usage_based_billing_enabled: if session.is_staff() {
2730                    Some(true)
2731                } else {
2732                    billing_preferences
2733                        .map(|preferences| preferences.model_request_overages_enabled)
2734                },
2735                subscription_period: subscription_period.map(|(started_at, ended_at)| {
2736                    proto::SubscriptionPeriod {
2737                        started_at: started_at.timestamp() as u64,
2738                        ended_at: ended_at.timestamp() as u64,
2739                    }
2740                }),
2741                usage: usage.map(|usage| {
2742                    let plan = match plan {
2743                        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2744                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2745                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2746                    };
2747
2748                    let model_requests_limit = match plan.model_requests_limit() {
2749                        zed_llm_client::UsageLimit::Limited(limit) => {
2750                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
2751                                && feature_flags
2752                                    .iter()
2753                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2754                            {
2755                                1_000
2756                            } else {
2757                                limit
2758                            };
2759
2760                            zed_llm_client::UsageLimit::Limited(limit)
2761                        }
2762                        zed_llm_client::UsageLimit::Unlimited => {
2763                            zed_llm_client::UsageLimit::Unlimited
2764                        }
2765                    };
2766
2767                    proto::SubscriptionUsage {
2768                        model_requests_usage_amount: usage.model_requests as u32,
2769                        model_requests_usage_limit: Some(proto::UsageLimit {
2770                            variant: Some(match model_requests_limit {
2771                                zed_llm_client::UsageLimit::Limited(limit) => {
2772                                    proto::usage_limit::Variant::Limited(
2773                                        proto::usage_limit::Limited {
2774                                            limit: limit as u32,
2775                                        },
2776                                    )
2777                                }
2778                                zed_llm_client::UsageLimit::Unlimited => {
2779                                    proto::usage_limit::Variant::Unlimited(
2780                                        proto::usage_limit::Unlimited {},
2781                                    )
2782                                }
2783                            }),
2784                        }),
2785                        edit_predictions_usage_amount: usage.edit_predictions as u32,
2786                        edit_predictions_usage_limit: Some(proto::UsageLimit {
2787                            variant: Some(match plan.edit_predictions_limit() {
2788                                zed_llm_client::UsageLimit::Limited(limit) => {
2789                                    proto::usage_limit::Variant::Limited(
2790                                        proto::usage_limit::Limited {
2791                                            limit: limit as u32,
2792                                        },
2793                                    )
2794                                }
2795                                zed_llm_client::UsageLimit::Unlimited => {
2796                                    proto::usage_limit::Variant::Unlimited(
2797                                        proto::usage_limit::Unlimited {},
2798                                    )
2799                                }
2800                            }),
2801                        }),
2802                    }
2803                }),
2804            },
2805        )
2806        .trace_err();
2807
2808    Ok(())
2809}
2810
2811async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2812    subscribe_user_to_channels(session.user_id(), &session).await?;
2813    Ok(())
2814}
2815
2816async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2817    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2818    let mut pool = session.connection_pool().await;
2819    for membership in &channels_for_user.channel_memberships {
2820        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2821    }
2822    session.peer.send(
2823        session.connection_id,
2824        build_update_user_channels(&channels_for_user),
2825    )?;
2826    session.peer.send(
2827        session.connection_id,
2828        build_channels_update(channels_for_user),
2829    )?;
2830    Ok(())
2831}
2832
2833/// Creates a new channel.
2834async fn create_channel(
2835    request: proto::CreateChannel,
2836    response: Response<proto::CreateChannel>,
2837    session: Session,
2838) -> Result<()> {
2839    let db = session.db().await;
2840
2841    let parent_id = request.parent_id.map(ChannelId::from_proto);
2842    let (channel, membership) = db
2843        .create_channel(&request.name, parent_id, session.user_id())
2844        .await?;
2845
2846    let root_id = channel.root_id();
2847    let channel = Channel::from_model(channel);
2848
2849    response.send(proto::CreateChannelResponse {
2850        channel: Some(channel.to_proto()),
2851        parent_id: request.parent_id,
2852    })?;
2853
2854    let mut connection_pool = session.connection_pool().await;
2855    if let Some(membership) = membership {
2856        connection_pool.subscribe_to_channel(
2857            membership.user_id,
2858            membership.channel_id,
2859            membership.role,
2860        );
2861        let update = proto::UpdateUserChannels {
2862            channel_memberships: vec![proto::ChannelMembership {
2863                channel_id: membership.channel_id.to_proto(),
2864                role: membership.role.into(),
2865            }],
2866            ..Default::default()
2867        };
2868        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2869            session.peer.send(connection_id, update.clone())?;
2870        }
2871    }
2872
2873    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2874        if !role.can_see_channel(channel.visibility) {
2875            continue;
2876        }
2877
2878        let update = proto::UpdateChannels {
2879            channels: vec![channel.to_proto()],
2880            ..Default::default()
2881        };
2882        session.peer.send(connection_id, update.clone())?;
2883    }
2884
2885    Ok(())
2886}
2887
2888/// Delete a channel
2889async fn delete_channel(
2890    request: proto::DeleteChannel,
2891    response: Response<proto::DeleteChannel>,
2892    session: Session,
2893) -> Result<()> {
2894    let db = session.db().await;
2895
2896    let channel_id = request.channel_id;
2897    let (root_channel, removed_channels) = db
2898        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2899        .await?;
2900    response.send(proto::Ack {})?;
2901
2902    // Notify members of removed channels
2903    let mut update = proto::UpdateChannels::default();
2904    update
2905        .delete_channels
2906        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2907
2908    let connection_pool = session.connection_pool().await;
2909    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2910        session.peer.send(connection_id, update.clone())?;
2911    }
2912
2913    Ok(())
2914}
2915
2916/// Invite someone to join a channel.
2917async fn invite_channel_member(
2918    request: proto::InviteChannelMember,
2919    response: Response<proto::InviteChannelMember>,
2920    session: Session,
2921) -> Result<()> {
2922    let db = session.db().await;
2923    let channel_id = ChannelId::from_proto(request.channel_id);
2924    let invitee_id = UserId::from_proto(request.user_id);
2925    let InviteMemberResult {
2926        channel,
2927        notifications,
2928    } = db
2929        .invite_channel_member(
2930            channel_id,
2931            invitee_id,
2932            session.user_id(),
2933            request.role().into(),
2934        )
2935        .await?;
2936
2937    let update = proto::UpdateChannels {
2938        channel_invitations: vec![channel.to_proto()],
2939        ..Default::default()
2940    };
2941
2942    let connection_pool = session.connection_pool().await;
2943    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2944        session.peer.send(connection_id, update.clone())?;
2945    }
2946
2947    send_notifications(&connection_pool, &session.peer, notifications);
2948
2949    response.send(proto::Ack {})?;
2950    Ok(())
2951}
2952
2953/// remove someone from a channel
2954async fn remove_channel_member(
2955    request: proto::RemoveChannelMember,
2956    response: Response<proto::RemoveChannelMember>,
2957    session: Session,
2958) -> Result<()> {
2959    let db = session.db().await;
2960    let channel_id = ChannelId::from_proto(request.channel_id);
2961    let member_id = UserId::from_proto(request.user_id);
2962
2963    let RemoveChannelMemberResult {
2964        membership_update,
2965        notification_id,
2966    } = db
2967        .remove_channel_member(channel_id, member_id, session.user_id())
2968        .await?;
2969
2970    let mut connection_pool = session.connection_pool().await;
2971    notify_membership_updated(
2972        &mut connection_pool,
2973        membership_update,
2974        member_id,
2975        &session.peer,
2976    );
2977    for connection_id in connection_pool.user_connection_ids(member_id) {
2978        if let Some(notification_id) = notification_id {
2979            session
2980                .peer
2981                .send(
2982                    connection_id,
2983                    proto::DeleteNotification {
2984                        notification_id: notification_id.to_proto(),
2985                    },
2986                )
2987                .trace_err();
2988        }
2989    }
2990
2991    response.send(proto::Ack {})?;
2992    Ok(())
2993}
2994
2995/// Toggle the channel between public and private.
2996/// Care is taken to maintain the invariant that public channels only descend from public channels,
2997/// (though members-only channels can appear at any point in the hierarchy).
2998async fn set_channel_visibility(
2999    request: proto::SetChannelVisibility,
3000    response: Response<proto::SetChannelVisibility>,
3001    session: Session,
3002) -> Result<()> {
3003    let db = session.db().await;
3004    let channel_id = ChannelId::from_proto(request.channel_id);
3005    let visibility = request.visibility().into();
3006
3007    let channel_model = db
3008        .set_channel_visibility(channel_id, visibility, session.user_id())
3009        .await?;
3010    let root_id = channel_model.root_id();
3011    let channel = Channel::from_model(channel_model);
3012
3013    let mut connection_pool = session.connection_pool().await;
3014    for (user_id, role) in connection_pool
3015        .channel_user_ids(root_id)
3016        .collect::<Vec<_>>()
3017        .into_iter()
3018    {
3019        let update = if role.can_see_channel(channel.visibility) {
3020            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3021            proto::UpdateChannels {
3022                channels: vec![channel.to_proto()],
3023                ..Default::default()
3024            }
3025        } else {
3026            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3027            proto::UpdateChannels {
3028                delete_channels: vec![channel.id.to_proto()],
3029                ..Default::default()
3030            }
3031        };
3032
3033        for connection_id in connection_pool.user_connection_ids(user_id) {
3034            session.peer.send(connection_id, update.clone())?;
3035        }
3036    }
3037
3038    response.send(proto::Ack {})?;
3039    Ok(())
3040}
3041
3042/// Alter the role for a user in the channel.
3043async fn set_channel_member_role(
3044    request: proto::SetChannelMemberRole,
3045    response: Response<proto::SetChannelMemberRole>,
3046    session: Session,
3047) -> Result<()> {
3048    let db = session.db().await;
3049    let channel_id = ChannelId::from_proto(request.channel_id);
3050    let member_id = UserId::from_proto(request.user_id);
3051    let result = db
3052        .set_channel_member_role(
3053            channel_id,
3054            session.user_id(),
3055            member_id,
3056            request.role().into(),
3057        )
3058        .await?;
3059
3060    match result {
3061        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3062            let mut connection_pool = session.connection_pool().await;
3063            notify_membership_updated(
3064                &mut connection_pool,
3065                membership_update,
3066                member_id,
3067                &session.peer,
3068            )
3069        }
3070        db::SetMemberRoleResult::InviteUpdated(channel) => {
3071            let update = proto::UpdateChannels {
3072                channel_invitations: vec![channel.to_proto()],
3073                ..Default::default()
3074            };
3075
3076            for connection_id in session
3077                .connection_pool()
3078                .await
3079                .user_connection_ids(member_id)
3080            {
3081                session.peer.send(connection_id, update.clone())?;
3082            }
3083        }
3084    }
3085
3086    response.send(proto::Ack {})?;
3087    Ok(())
3088}
3089
3090/// Change the name of a channel
3091async fn rename_channel(
3092    request: proto::RenameChannel,
3093    response: Response<proto::RenameChannel>,
3094    session: Session,
3095) -> Result<()> {
3096    let db = session.db().await;
3097    let channel_id = ChannelId::from_proto(request.channel_id);
3098    let channel_model = db
3099        .rename_channel(channel_id, session.user_id(), &request.name)
3100        .await?;
3101    let root_id = channel_model.root_id();
3102    let channel = Channel::from_model(channel_model);
3103
3104    response.send(proto::RenameChannelResponse {
3105        channel: Some(channel.to_proto()),
3106    })?;
3107
3108    let connection_pool = session.connection_pool().await;
3109    let update = proto::UpdateChannels {
3110        channels: vec![channel.to_proto()],
3111        ..Default::default()
3112    };
3113    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3114        if role.can_see_channel(channel.visibility) {
3115            session.peer.send(connection_id, update.clone())?;
3116        }
3117    }
3118
3119    Ok(())
3120}
3121
3122/// Move a channel to a new parent.
3123async fn move_channel(
3124    request: proto::MoveChannel,
3125    response: Response<proto::MoveChannel>,
3126    session: Session,
3127) -> Result<()> {
3128    let channel_id = ChannelId::from_proto(request.channel_id);
3129    let to = ChannelId::from_proto(request.to);
3130
3131    let (root_id, channels) = session
3132        .db()
3133        .await
3134        .move_channel(channel_id, to, session.user_id())
3135        .await?;
3136
3137    let connection_pool = session.connection_pool().await;
3138    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3139        let channels = channels
3140            .iter()
3141            .filter_map(|channel| {
3142                if role.can_see_channel(channel.visibility) {
3143                    Some(channel.to_proto())
3144                } else {
3145                    None
3146                }
3147            })
3148            .collect::<Vec<_>>();
3149        if channels.is_empty() {
3150            continue;
3151        }
3152
3153        let update = proto::UpdateChannels {
3154            channels,
3155            ..Default::default()
3156        };
3157
3158        session.peer.send(connection_id, update.clone())?;
3159    }
3160
3161    response.send(Ack {})?;
3162    Ok(())
3163}
3164
3165/// Get the list of channel members
3166async fn get_channel_members(
3167    request: proto::GetChannelMembers,
3168    response: Response<proto::GetChannelMembers>,
3169    session: Session,
3170) -> Result<()> {
3171    let db = session.db().await;
3172    let channel_id = ChannelId::from_proto(request.channel_id);
3173    let limit = if request.limit == 0 {
3174        u16::MAX as u64
3175    } else {
3176        request.limit
3177    };
3178    let (members, users) = db
3179        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3180        .await?;
3181    response.send(proto::GetChannelMembersResponse { members, users })?;
3182    Ok(())
3183}
3184
3185/// Accept or decline a channel invitation.
3186async fn respond_to_channel_invite(
3187    request: proto::RespondToChannelInvite,
3188    response: Response<proto::RespondToChannelInvite>,
3189    session: Session,
3190) -> Result<()> {
3191    let db = session.db().await;
3192    let channel_id = ChannelId::from_proto(request.channel_id);
3193    let RespondToChannelInvite {
3194        membership_update,
3195        notifications,
3196    } = db
3197        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3198        .await?;
3199
3200    let mut connection_pool = session.connection_pool().await;
3201    if let Some(membership_update) = membership_update {
3202        notify_membership_updated(
3203            &mut connection_pool,
3204            membership_update,
3205            session.user_id(),
3206            &session.peer,
3207        );
3208    } else {
3209        let update = proto::UpdateChannels {
3210            remove_channel_invitations: vec![channel_id.to_proto()],
3211            ..Default::default()
3212        };
3213
3214        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3215            session.peer.send(connection_id, update.clone())?;
3216        }
3217    };
3218
3219    send_notifications(&connection_pool, &session.peer, notifications);
3220
3221    response.send(proto::Ack {})?;
3222
3223    Ok(())
3224}
3225
3226/// Join the channels' room
3227async fn join_channel(
3228    request: proto::JoinChannel,
3229    response: Response<proto::JoinChannel>,
3230    session: Session,
3231) -> Result<()> {
3232    let channel_id = ChannelId::from_proto(request.channel_id);
3233    join_channel_internal(channel_id, Box::new(response), session).await
3234}
3235
3236trait JoinChannelInternalResponse {
3237    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3238}
3239impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3240    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3241        Response::<proto::JoinChannel>::send(self, result)
3242    }
3243}
3244impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3245    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3246        Response::<proto::JoinRoom>::send(self, result)
3247    }
3248}
3249
/// Joins the room of `channel_id` on behalf of the session's user and connection.
///
/// `response` may be any `JoinChannelInternalResponse` implementor; it is answered with a
/// `proto::JoinRoomResponse`. In order, this function:
///
/// 1. leaves the room of a stale connection of the same user, if the database reports one,
/// 2. records the join in the database,
/// 3. builds LiveKit connection info when a LiveKit client is configured,
/// 4. sends the response to the requester,
/// 5. sends membership, room, channel-participant and contact updates.
///
/// Returns an error if a database call, the response send, or the final contact update
/// fails, or if the database did not return a channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The block scopes the database handle: it is released before `channel_updated` and
    // `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room and re-acquire it after.
            // NOTE(review): presumably `leave_room_for_session` acquires the handle itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when creating the token fails
        // (the `?` after `trace_err()` leaves the closure early with `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other role gets a
                    // room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Only present when joining changed the user's channel memberships.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard below is a temporary that lives only for this statement, so it is
    // released before `update_user_contacts` runs.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3345
3346/// Start editing the channel notes
3347async fn join_channel_buffer(
3348    request: proto::JoinChannelBuffer,
3349    response: Response<proto::JoinChannelBuffer>,
3350    session: Session,
3351) -> Result<()> {
3352    let db = session.db().await;
3353    let channel_id = ChannelId::from_proto(request.channel_id);
3354
3355    let open_response = db
3356        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3357        .await?;
3358
3359    let collaborators = open_response.collaborators.clone();
3360    response.send(open_response)?;
3361
3362    let update = UpdateChannelBufferCollaborators {
3363        channel_id: channel_id.to_proto(),
3364        collaborators: collaborators.clone(),
3365    };
3366    channel_buffer_updated(
3367        session.connection_id,
3368        collaborators
3369            .iter()
3370            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3371        &update,
3372        &session.peer,
3373    );
3374
3375    Ok(())
3376}
3377
3378/// Edit the channel notes
3379async fn update_channel_buffer(
3380    request: proto::UpdateChannelBuffer,
3381    session: Session,
3382) -> Result<()> {
3383    let db = session.db().await;
3384    let channel_id = ChannelId::from_proto(request.channel_id);
3385
3386    let (collaborators, epoch, version) = db
3387        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3388        .await?;
3389
3390    channel_buffer_updated(
3391        session.connection_id,
3392        collaborators.clone(),
3393        &proto::UpdateChannelBuffer {
3394            channel_id: channel_id.to_proto(),
3395            operations: request.operations,
3396        },
3397        &session.peer,
3398    );
3399
3400    let pool = &*session.connection_pool().await;
3401
3402    let non_collaborators =
3403        pool.channel_connection_ids(channel_id)
3404            .filter_map(|(connection_id, _)| {
3405                if collaborators.contains(&connection_id) {
3406                    None
3407                } else {
3408                    Some(connection_id)
3409                }
3410            });
3411
3412    broadcast(None, non_collaborators, |peer_id| {
3413        session.peer.send(
3414            peer_id,
3415            proto::UpdateChannels {
3416                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3417                    channel_id: channel_id.to_proto(),
3418                    epoch: epoch as u64,
3419                    version: version.clone(),
3420                }],
3421                ..Default::default()
3422            },
3423        )
3424    });
3425
3426    Ok(())
3427}
3428
3429/// Rejoin the channel notes after a connection blip
3430async fn rejoin_channel_buffers(
3431    request: proto::RejoinChannelBuffers,
3432    response: Response<proto::RejoinChannelBuffers>,
3433    session: Session,
3434) -> Result<()> {
3435    let db = session.db().await;
3436    let buffers = db
3437        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3438        .await?;
3439
3440    for rejoined_buffer in &buffers {
3441        let collaborators_to_notify = rejoined_buffer
3442            .buffer
3443            .collaborators
3444            .iter()
3445            .filter_map(|c| Some(c.peer_id?.into()));
3446        channel_buffer_updated(
3447            session.connection_id,
3448            collaborators_to_notify,
3449            &proto::UpdateChannelBufferCollaborators {
3450                channel_id: rejoined_buffer.buffer.channel_id,
3451                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3452            },
3453            &session.peer,
3454        );
3455    }
3456
3457    response.send(proto::RejoinChannelBuffersResponse {
3458        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3459    })?;
3460
3461    Ok(())
3462}
3463
3464/// Stop editing the channel notes
3465async fn leave_channel_buffer(
3466    request: proto::LeaveChannelBuffer,
3467    response: Response<proto::LeaveChannelBuffer>,
3468    session: Session,
3469) -> Result<()> {
3470    let db = session.db().await;
3471    let channel_id = ChannelId::from_proto(request.channel_id);
3472
3473    let left_buffer = db
3474        .leave_channel_buffer(channel_id, session.connection_id)
3475        .await?;
3476
3477    response.send(Ack {})?;
3478
3479    channel_buffer_updated(
3480        session.connection_id,
3481        left_buffer.connections,
3482        &proto::UpdateChannelBufferCollaborators {
3483            channel_id: channel_id.to_proto(),
3484            collaborators: left_buffer.collaborators,
3485        },
3486        &session.peer,
3487    );
3488
3489    Ok(())
3490}
3491
3492fn channel_buffer_updated<T: EnvelopedMessage>(
3493    sender_id: ConnectionId,
3494    collaborators: impl IntoIterator<Item = ConnectionId>,
3495    message: &T,
3496    peer: &Peer,
3497) {
3498    broadcast(Some(sender_id), collaborators, |peer_id| {
3499        peer.send(peer_id, message.clone())
3500    });
3501}
3502
3503fn send_notifications(
3504    connection_pool: &ConnectionPool,
3505    peer: &Peer,
3506    notifications: db::NotificationBatch,
3507) {
3508    for (user_id, notification) in notifications {
3509        for connection_id in connection_pool.user_connection_ids(user_id) {
3510            if let Err(error) = peer.send(
3511                connection_id,
3512                proto::AddNotification {
3513                    notification: Some(notification.clone()),
3514                },
3515            ) {
3516                tracing::error!(
3517                    "failed to send notification to {:?} {}",
3518                    connection_id,
3519                    error
3520                );
3521            }
3522        }
3523    }
3524}
3525
3526/// Send a message to the channel
3527async fn send_channel_message(
3528    request: proto::SendChannelMessage,
3529    response: Response<proto::SendChannelMessage>,
3530    session: Session,
3531) -> Result<()> {
3532    // Validate the message body.
3533    let body = request.body.trim().to_string();
3534    if body.len() > MAX_MESSAGE_LEN {
3535        return Err(anyhow!("message is too long"))?;
3536    }
3537    if body.is_empty() {
3538        return Err(anyhow!("message can't be blank"))?;
3539    }
3540
3541    // TODO: adjust mentions if body is trimmed
3542
3543    let timestamp = OffsetDateTime::now_utc();
3544    let nonce = request
3545        .nonce
3546        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3547
3548    let channel_id = ChannelId::from_proto(request.channel_id);
3549    let CreatedChannelMessage {
3550        message_id,
3551        participant_connection_ids,
3552        notifications,
3553    } = session
3554        .db()
3555        .await
3556        .create_channel_message(
3557            channel_id,
3558            session.user_id(),
3559            &body,
3560            &request.mentions,
3561            timestamp,
3562            nonce.clone().into(),
3563            request.reply_to_message_id.map(MessageId::from_proto),
3564        )
3565        .await?;
3566
3567    let message = proto::ChannelMessage {
3568        sender_id: session.user_id().to_proto(),
3569        id: message_id.to_proto(),
3570        body,
3571        mentions: request.mentions,
3572        timestamp: timestamp.unix_timestamp() as u64,
3573        nonce: Some(nonce),
3574        reply_to_message_id: request.reply_to_message_id,
3575        edited_at: None,
3576    };
3577    broadcast(
3578        Some(session.connection_id),
3579        participant_connection_ids.clone(),
3580        |connection| {
3581            session.peer.send(
3582                connection,
3583                proto::ChannelMessageSent {
3584                    channel_id: channel_id.to_proto(),
3585                    message: Some(message.clone()),
3586                },
3587            )
3588        },
3589    );
3590    response.send(proto::SendChannelMessageResponse {
3591        message: Some(message),
3592    })?;
3593
3594    let pool = &*session.connection_pool().await;
3595    let non_participants =
3596        pool.channel_connection_ids(channel_id)
3597            .filter_map(|(connection_id, _)| {
3598                if participant_connection_ids.contains(&connection_id) {
3599                    None
3600                } else {
3601                    Some(connection_id)
3602                }
3603            });
3604    broadcast(None, non_participants, |peer_id| {
3605        session.peer.send(
3606            peer_id,
3607            proto::UpdateChannels {
3608                latest_channel_message_ids: vec![proto::ChannelMessageId {
3609                    channel_id: channel_id.to_proto(),
3610                    message_id: message_id.to_proto(),
3611                }],
3612                ..Default::default()
3613            },
3614        )
3615    });
3616    send_notifications(pool, &session.peer, notifications);
3617
3618    Ok(())
3619}
3620
3621/// Delete a channel message
3622async fn remove_channel_message(
3623    request: proto::RemoveChannelMessage,
3624    response: Response<proto::RemoveChannelMessage>,
3625    session: Session,
3626) -> Result<()> {
3627    let channel_id = ChannelId::from_proto(request.channel_id);
3628    let message_id = MessageId::from_proto(request.message_id);
3629    let (connection_ids, existing_notification_ids) = session
3630        .db()
3631        .await
3632        .remove_channel_message(channel_id, message_id, session.user_id())
3633        .await?;
3634
3635    broadcast(
3636        Some(session.connection_id),
3637        connection_ids,
3638        move |connection| {
3639            session.peer.send(connection, request.clone())?;
3640
3641            for notification_id in &existing_notification_ids {
3642                session.peer.send(
3643                    connection,
3644                    proto::DeleteNotification {
3645                        notification_id: (*notification_id).to_proto(),
3646                    },
3647                )?;
3648            }
3649
3650            Ok(())
3651        },
3652    );
3653    response.send(proto::Ack {})?;
3654    Ok(())
3655}
3656
3657async fn update_channel_message(
3658    request: proto::UpdateChannelMessage,
3659    response: Response<proto::UpdateChannelMessage>,
3660    session: Session,
3661) -> Result<()> {
3662    let channel_id = ChannelId::from_proto(request.channel_id);
3663    let message_id = MessageId::from_proto(request.message_id);
3664    let updated_at = OffsetDateTime::now_utc();
3665    let UpdatedChannelMessage {
3666        message_id,
3667        participant_connection_ids,
3668        notifications,
3669        reply_to_message_id,
3670        timestamp,
3671        deleted_mention_notification_ids,
3672        updated_mention_notifications,
3673    } = session
3674        .db()
3675        .await
3676        .update_channel_message(
3677            channel_id,
3678            message_id,
3679            session.user_id(),
3680            request.body.as_str(),
3681            &request.mentions,
3682            updated_at,
3683        )
3684        .await?;
3685
3686    let nonce = request
3687        .nonce
3688        .clone()
3689        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3690
3691    let message = proto::ChannelMessage {
3692        sender_id: session.user_id().to_proto(),
3693        id: message_id.to_proto(),
3694        body: request.body.clone(),
3695        mentions: request.mentions.clone(),
3696        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3697        nonce: Some(nonce),
3698        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3699        edited_at: Some(updated_at.unix_timestamp() as u64),
3700    };
3701
3702    response.send(proto::Ack {})?;
3703
3704    let pool = &*session.connection_pool().await;
3705    broadcast(
3706        Some(session.connection_id),
3707        participant_connection_ids,
3708        |connection| {
3709            session.peer.send(
3710                connection,
3711                proto::ChannelMessageUpdate {
3712                    channel_id: channel_id.to_proto(),
3713                    message: Some(message.clone()),
3714                },
3715            )?;
3716
3717            for notification_id in &deleted_mention_notification_ids {
3718                session.peer.send(
3719                    connection,
3720                    proto::DeleteNotification {
3721                        notification_id: (*notification_id).to_proto(),
3722                    },
3723                )?;
3724            }
3725
3726            for notification in &updated_mention_notifications {
3727                session.peer.send(
3728                    connection,
3729                    proto::UpdateNotification {
3730                        notification: Some(notification.clone()),
3731                    },
3732                )?;
3733            }
3734
3735            Ok(())
3736        },
3737    );
3738
3739    send_notifications(pool, &session.peer, notifications);
3740
3741    Ok(())
3742}
3743
3744/// Mark a channel message as read
3745async fn acknowledge_channel_message(
3746    request: proto::AckChannelMessage,
3747    session: Session,
3748) -> Result<()> {
3749    let channel_id = ChannelId::from_proto(request.channel_id);
3750    let message_id = MessageId::from_proto(request.message_id);
3751    let notifications = session
3752        .db()
3753        .await
3754        .observe_channel_message(channel_id, session.user_id(), message_id)
3755        .await?;
3756    send_notifications(
3757        &*session.connection_pool().await,
3758        &session.peer,
3759        notifications,
3760    );
3761    Ok(())
3762}
3763
3764/// Mark a buffer version as synced
3765async fn acknowledge_buffer_version(
3766    request: proto::AckBufferOperation,
3767    session: Session,
3768) -> Result<()> {
3769    let buffer_id = BufferId::from_proto(request.buffer_id);
3770    session
3771        .db()
3772        .await
3773        .observe_buffer_version(
3774            buffer_id,
3775            session.user_id(),
3776            request.epoch as i32,
3777            &request.version,
3778        )
3779        .await?;
3780    Ok(())
3781}
3782
3783/// Get a Supermaven API key for the user
3784async fn get_supermaven_api_key(
3785    _request: proto::GetSupermavenApiKey,
3786    response: Response<proto::GetSupermavenApiKey>,
3787    session: Session,
3788) -> Result<()> {
3789    let user_id: String = session.user_id().to_string();
3790    if !session.is_staff() {
3791        return Err(anyhow!("supermaven not enabled for this account"))?;
3792    }
3793
3794    let email = session
3795        .email()
3796        .ok_or_else(|| anyhow!("user must have an email"))?;
3797
3798    let supermaven_admin_api = session
3799        .supermaven_client
3800        .as_ref()
3801        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3802
3803    let result = supermaven_admin_api
3804        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3805        .await?;
3806
3807    response.send(proto::GetSupermavenApiKeyResponse {
3808        api_key: result.api_key,
3809    })?;
3810
3811    Ok(())
3812}
3813
3814/// Start receiving chat updates for a channel
3815async fn join_channel_chat(
3816    request: proto::JoinChannelChat,
3817    response: Response<proto::JoinChannelChat>,
3818    session: Session,
3819) -> Result<()> {
3820    let channel_id = ChannelId::from_proto(request.channel_id);
3821
3822    let db = session.db().await;
3823    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3824        .await?;
3825    let messages = db
3826        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3827        .await?;
3828    response.send(proto::JoinChannelChatResponse {
3829        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3830        messages,
3831    })?;
3832    Ok(())
3833}
3834
3835/// Stop receiving chat updates for a channel
3836async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3837    let channel_id = ChannelId::from_proto(request.channel_id);
3838    session
3839        .db()
3840        .await
3841        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3842        .await?;
3843    Ok(())
3844}
3845
3846/// Retrieve the chat history for a channel
3847async fn get_channel_messages(
3848    request: proto::GetChannelMessages,
3849    response: Response<proto::GetChannelMessages>,
3850    session: Session,
3851) -> Result<()> {
3852    let channel_id = ChannelId::from_proto(request.channel_id);
3853    let messages = session
3854        .db()
3855        .await
3856        .get_channel_messages(
3857            channel_id,
3858            session.user_id(),
3859            MESSAGE_COUNT_PER_PAGE,
3860            Some(MessageId::from_proto(request.before_message_id)),
3861        )
3862        .await?;
3863    response.send(proto::GetChannelMessagesResponse {
3864        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3865        messages,
3866    })?;
3867    Ok(())
3868}
3869
3870/// Retrieve specific chat messages
3871async fn get_channel_messages_by_id(
3872    request: proto::GetChannelMessagesById,
3873    response: Response<proto::GetChannelMessagesById>,
3874    session: Session,
3875) -> Result<()> {
3876    let message_ids = request
3877        .message_ids
3878        .iter()
3879        .map(|id| MessageId::from_proto(*id))
3880        .collect::<Vec<_>>();
3881    let messages = session
3882        .db()
3883        .await
3884        .get_channel_messages_by_id(session.user_id(), &message_ids)
3885        .await?;
3886    response.send(proto::GetChannelMessagesResponse {
3887        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888        messages,
3889    })?;
3890    Ok(())
3891}
3892
3893/// Retrieve the current users notifications
3894async fn get_notifications(
3895    request: proto::GetNotifications,
3896    response: Response<proto::GetNotifications>,
3897    session: Session,
3898) -> Result<()> {
3899    let notifications = session
3900        .db()
3901        .await
3902        .get_notifications(
3903            session.user_id(),
3904            NOTIFICATION_COUNT_PER_PAGE,
3905            request.before_id.map(db::NotificationId::from_proto),
3906        )
3907        .await?;
3908    response.send(proto::GetNotificationsResponse {
3909        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3910        notifications,
3911    })?;
3912    Ok(())
3913}
3914
3915/// Mark notifications as read
3916async fn mark_notification_as_read(
3917    request: proto::MarkNotificationRead,
3918    response: Response<proto::MarkNotificationRead>,
3919    session: Session,
3920) -> Result<()> {
3921    let database = &session.db().await;
3922    let notifications = database
3923        .mark_notification_as_read_by_id(
3924            session.user_id(),
3925            NotificationId::from_proto(request.notification_id),
3926        )
3927        .await?;
3928    send_notifications(
3929        &*session.connection_pool().await,
3930        &session.peer,
3931        notifications,
3932    );
3933    response.send(proto::Ack {})?;
3934    Ok(())
3935}
3936
3937/// Get the current users information
3938async fn get_private_user_info(
3939    _request: proto::GetPrivateUserInfo,
3940    response: Response<proto::GetPrivateUserInfo>,
3941    session: Session,
3942) -> Result<()> {
3943    let db = session.db().await;
3944
3945    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3946    let user = db
3947        .get_user_by_id(session.user_id())
3948        .await?
3949        .ok_or_else(|| anyhow!("user not found"))?;
3950    let flags = db.get_user_flags(session.user_id()).await?;
3951
3952    response.send(proto::GetPrivateUserInfoResponse {
3953        metrics_id,
3954        staff: user.admin,
3955        flags,
3956        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3957    })?;
3958    Ok(())
3959}
3960
3961/// Accept the terms of service (tos) on behalf of the current user
3962async fn accept_terms_of_service(
3963    _request: proto::AcceptTermsOfService,
3964    response: Response<proto::AcceptTermsOfService>,
3965    session: Session,
3966) -> Result<()> {
3967    let db = session.db().await;
3968
3969    let accepted_tos_at = Utc::now();
3970    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3971        .await?;
3972
3973    response.send(proto::AcceptTermsOfServiceResponse {
3974        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3975    })?;
3976    Ok(())
3977}
3978
/// The minimum age an account must have in order to use the LLM service (30 days).
// NOTE(review): no use of this constant is visible in this part of the file — confirm where
// the limit is enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3981
3982async fn get_llm_api_token(
3983    _request: proto::GetLlmToken,
3984    response: Response<proto::GetLlmToken>,
3985    session: Session,
3986) -> Result<()> {
3987    let db = session.db().await;
3988
3989    let flags = db.get_user_flags(session.user_id()).await?;
3990
3991    let user_id = session.user_id();
3992    let user = db
3993        .get_user_by_id(user_id)
3994        .await?
3995        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3996
3997    if user.accepted_tos_at.is_none() {
3998        Err(anyhow!("terms of service not accepted"))?
3999    }
4000
4001    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4002    let billing_preferences = db.get_billing_preferences(user.id).await?;
4003
4004    let token = LlmTokenClaims::create(
4005        &user,
4006        session.is_staff(),
4007        billing_preferences,
4008        &flags,
4009        billing_subscription,
4010        session.system_id.clone(),
4011        &session.app_state.config,
4012    )?;
4013    response.send(proto::GetLlmTokenResponse { token })?;
4014    Ok(())
4015}
4016
4017fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4018    let message = match message {
4019        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4020        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4021        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4022        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4023        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4024            code: frame.code.into(),
4025            reason: frame.reason.as_str().to_owned().into(),
4026        })),
4027        // We should never receive a frame while reading the message, according
4028        // to the `tungstenite` maintainers:
4029        //
4030        // > It cannot occur when you read messages from the WebSocket, but it
4031        // > can be used when you want to send the raw frames (e.g. you want to
4032        // > send the frames to the WebSocket without composing the full message first).
4033        // >
4034        // > — https://github.com/snapview/tungstenite-rs/issues/268
4035        TungsteniteMessage::Frame(_) => {
4036            bail!("received an unexpected frame while reading the message")
4037        }
4038    };
4039
4040    Ok(message)
4041}
4042
4043fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4044    match message {
4045        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4046        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4047        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4048        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4049        AxumMessage::Close(frame) => {
4050            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4051                code: frame.code.into(),
4052                reason: frame.reason.as_ref().into(),
4053            }))
4054        }
4055    }
4056}
4057
4058fn notify_membership_updated(
4059    connection_pool: &mut ConnectionPool,
4060    result: MembershipUpdated,
4061    user_id: UserId,
4062    peer: &Peer,
4063) {
4064    for membership in &result.new_channels.channel_memberships {
4065        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4066    }
4067    for channel_id in &result.removed_channels {
4068        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4069    }
4070
4071    let user_channels_update = proto::UpdateUserChannels {
4072        channel_memberships: result
4073            .new_channels
4074            .channel_memberships
4075            .iter()
4076            .map(|cm| proto::ChannelMembership {
4077                channel_id: cm.channel_id.to_proto(),
4078                role: cm.role.into(),
4079            })
4080            .collect(),
4081        ..Default::default()
4082    };
4083
4084    let mut update = build_channels_update(result.new_channels);
4085    update.delete_channels = result
4086        .removed_channels
4087        .into_iter()
4088        .map(|id| id.to_proto())
4089        .collect();
4090    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4091
4092    for connection_id in connection_pool.user_connection_ids(user_id) {
4093        peer.send(connection_id, user_channels_update.clone())
4094            .trace_err();
4095        peer.send(connection_id, update.clone()).trace_err();
4096    }
4097}
4098
4099fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4100    proto::UpdateUserChannels {
4101        channel_memberships: channels
4102            .channel_memberships
4103            .iter()
4104            .map(|m| proto::ChannelMembership {
4105                channel_id: m.channel_id.to_proto(),
4106                role: m.role.into(),
4107            })
4108            .collect(),
4109        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4110        observed_channel_message_id: channels.observed_channel_messages.clone(),
4111    }
4112}
4113
4114fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4115    let mut update = proto::UpdateChannels::default();
4116
4117    for channel in channels.channels {
4118        update.channels.push(channel.to_proto());
4119    }
4120
4121    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4122    update.latest_channel_message_ids = channels.latest_channel_messages;
4123
4124    for (channel_id, participants) in channels.channel_participants {
4125        update
4126            .channel_participants
4127            .push(proto::ChannelParticipants {
4128                channel_id: channel_id.to_proto(),
4129                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4130            });
4131    }
4132
4133    for channel in channels.invited_channels {
4134        update.channel_invitations.push(channel.to_proto());
4135    }
4136
4137    update
4138}
4139
4140fn build_initial_contacts_update(
4141    contacts: Vec<db::Contact>,
4142    pool: &ConnectionPool,
4143) -> proto::UpdateContacts {
4144    let mut update = proto::UpdateContacts::default();
4145
4146    for contact in contacts {
4147        match contact {
4148            db::Contact::Accepted { user_id, busy } => {
4149                update.contacts.push(contact_for_user(user_id, busy, pool));
4150            }
4151            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4152            db::Contact::Incoming { user_id } => {
4153                update
4154                    .incoming_requests
4155                    .push(proto::IncomingContactRequest {
4156                        requester_id: user_id.to_proto(),
4157                    })
4158            }
4159        }
4160    }
4161
4162    update
4163}
4164
4165fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4166    proto::Contact {
4167        user_id: user_id.to_proto(),
4168        online: pool.is_user_online(user_id),
4169        busy,
4170    }
4171}
4172
4173fn room_updated(room: &proto::Room, peer: &Peer) {
4174    broadcast(
4175        None,
4176        room.participants
4177            .iter()
4178            .filter_map(|participant| Some(participant.peer_id?.into())),
4179        |peer_id| {
4180            peer.send(
4181                peer_id,
4182                proto::RoomUpdated {
4183                    room: Some(room.clone()),
4184                },
4185            )
4186        },
4187    );
4188}
4189
4190fn channel_updated(
4191    channel: &db::channel::Model,
4192    room: &proto::Room,
4193    peer: &Peer,
4194    pool: &ConnectionPool,
4195) {
4196    let participants = room
4197        .participants
4198        .iter()
4199        .map(|p| p.user_id)
4200        .collect::<Vec<_>>();
4201
4202    broadcast(
4203        None,
4204        pool.channel_connection_ids(channel.root_id())
4205            .filter_map(|(channel_id, role)| {
4206                role.can_see_channel(channel.visibility)
4207                    .then_some(channel_id)
4208            }),
4209        |peer_id| {
4210            peer.send(
4211                peer_id,
4212                proto::UpdateChannels {
4213                    channel_participants: vec![proto::ChannelParticipants {
4214                        channel_id: channel.id.to_proto(),
4215                        participant_user_ids: participants.clone(),
4216                    }],
4217                    ..Default::default()
4218                },
4219            )
4220        },
4221    );
4222}
4223
4224async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4225    let db = session.db().await;
4226
4227    let contacts = db.get_contacts(user_id).await?;
4228    let busy = db.is_user_busy(user_id).await?;
4229
4230    let pool = session.connection_pool().await;
4231    let updated_contact = contact_for_user(user_id, busy, &pool);
4232    for contact in contacts {
4233        if let db::Contact::Accepted {
4234            user_id: contact_user_id,
4235            ..
4236        } = contact
4237        {
4238            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4239                session
4240                    .peer
4241                    .send(
4242                        contact_conn_id,
4243                        proto::UpdateContacts {
4244                            contacts: vec![updated_contact.clone()],
4245                            remove_contacts: Default::default(),
4246                            incoming_requests: Default::default(),
4247                            remove_incoming_requests: Default::default(),
4248                            outgoing_requests: Default::default(),
4249                            remove_outgoing_requests: Default::default(),
4250                        },
4251                    )
4252                    .trace_err();
4253            }
4254        }
4255    }
4256    Ok(())
4257}
4258
4259async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4260    let mut contacts_to_update = HashSet::default();
4261
4262    let room_id;
4263    let canceled_calls_to_user_ids;
4264    let livekit_room;
4265    let delete_livekit_room;
4266    let room;
4267    let channel;
4268
4269    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4270        contacts_to_update.insert(session.user_id());
4271
4272        for project in left_room.left_projects.values() {
4273            project_left(project, session);
4274        }
4275
4276        room_id = RoomId::from_proto(left_room.room.id);
4277        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4278        livekit_room = mem::take(&mut left_room.room.livekit_room);
4279        delete_livekit_room = left_room.deleted;
4280        room = mem::take(&mut left_room.room);
4281        channel = mem::take(&mut left_room.channel);
4282
4283        room_updated(&room, &session.peer);
4284    } else {
4285        return Ok(());
4286    }
4287
4288    if let Some(channel) = channel {
4289        channel_updated(
4290            &channel,
4291            &room,
4292            &session.peer,
4293            &*session.connection_pool().await,
4294        );
4295    }
4296
4297    {
4298        let pool = session.connection_pool().await;
4299        for canceled_user_id in canceled_calls_to_user_ids {
4300            for connection_id in pool.user_connection_ids(canceled_user_id) {
4301                session
4302                    .peer
4303                    .send(
4304                        connection_id,
4305                        proto::CallCanceled {
4306                            room_id: room_id.to_proto(),
4307                        },
4308                    )
4309                    .trace_err();
4310            }
4311            contacts_to_update.insert(canceled_user_id);
4312        }
4313    }
4314
4315    for contact_user_id in contacts_to_update {
4316        update_user_contacts(contact_user_id, session).await?;
4317    }
4318
4319    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4320        live_kit
4321            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4322            .await
4323            .trace_err();
4324
4325        if delete_livekit_room {
4326            live_kit.delete_room(livekit_room).await.trace_err();
4327        }
4328    }
4329
4330    Ok(())
4331}
4332
4333async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4334    let left_channel_buffers = session
4335        .db()
4336        .await
4337        .leave_channel_buffers(session.connection_id)
4338        .await?;
4339
4340    for left_buffer in left_channel_buffers {
4341        channel_buffer_updated(
4342            session.connection_id,
4343            left_buffer.connections,
4344            &proto::UpdateChannelBufferCollaborators {
4345                channel_id: left_buffer.channel_id.to_proto(),
4346                collaborators: left_buffer.collaborators,
4347            },
4348            &session.peer,
4349        );
4350    }
4351
4352    Ok(())
4353}
4354
4355fn project_left(project: &db::LeftProject, session: &Session) {
4356    for connection_id in &project.connection_ids {
4357        if project.should_unshare {
4358            session
4359                .peer
4360                .send(
4361                    *connection_id,
4362                    proto::UnshareProject {
4363                        project_id: project.id.to_proto(),
4364                    },
4365                )
4366                .trace_err();
4367        } else {
4368            session
4369                .peer
4370                .send(
4371                    *connection_id,
4372                    proto::RemoveProjectCollaborator {
4373                        project_id: project.id.to_proto(),
4374                        peer_id: Some(session.connection_id.into()),
4375                    },
4376                )
4377                .trace_err();
4378        }
4379    }
4380}
4381
4382pub trait ResultExt {
4383    type Ok;
4384
4385    fn trace_err(self) -> Option<Self::Ok>;
4386}
4387
4388impl<T, E> ResultExt for Result<T, E>
4389where
4390    E: std::fmt::Debug,
4391{
4392    type Ok = T;
4393
4394    #[track_caller]
4395    fn trace_err(self) -> Option<T> {
4396        match self {
4397            Ok(value) => Some(value),
4398            Err(error) => {
4399                tracing::error!("{:?}", error);
4400                None
4401            }
4402        }
4403    }
4404}