rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::LlmTokenClaims;
   6use crate::{
   7    AppState, Error, Result, auth,
   8    db::{
   9        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  10        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  11        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  12        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15};
  16use anyhow::{Context as _, anyhow, bail};
  17use async_tungstenite::tungstenite::{
  18    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  19};
  20use axum::{
  21    Extension, Router, TypedHeader,
  22    body::Body,
  23    extract::{
  24        ConnectInfo, WebSocketUpgrade,
  25        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{MutexGuard, Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
/// Reconnection grace period.
// NOTE(review): presumably how long a disconnected client's state is kept so it
// can reconnect — usage is not visible in this chunk; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting 15s ensures
/// those pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this chunk; from their names they
// are presumably the channel-message page size, the maximum channel-message length, and
// the notification page size — confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  86
/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It receives the boxed envelope plus the sender's `Session` and returns a future that
/// performs the work (and its logging) when awaited.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
/// One-shot responder handed to request handlers for a request of type `R`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set once a reply is sent; the request dispatcher reads it to detect handlers
    /// that return `Ok(())` without responding.
    responded: Arc<AtomicBool>,
}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
/// The identity a connection is authenticated as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`; the session reports `user`'s identity, and `admin`
    /// is recorded as the impersonator.
    Impersonated { user: User, admin: User },
}
 109
 110impl Principal {
 111    fn update_span(&self, span: &tracing::Span) {
 112        match &self {
 113            Principal::User(user) => {
 114                span.record("user_id", user.id.0);
 115                span.record("login", &user.github_login);
 116            }
 117            Principal::Impersonated { user, admin } => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120                span.record("impersonator", &admin.github_login);
 121            }
 122        }
 123    }
 124}
 125
/// Per-connection state passed (by value, cloned) to every message handler.
#[derive(Clone)]
struct Session {
    /// The identity this connection is authenticated as.
    principal: Principal,
    /// The connection this session belongs to.
    connection_id: ConnectionId,
    /// Shared database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `#[allow(unused)]` suggests nothing reads this field yet; confirm
    // before removing the allow or relying on the value.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 141
 142impl Session {
 143    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 144        #[cfg(test)]
 145        tokio::task::yield_now().await;
 146        let guard = self.db.lock().await;
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        guard
 150    }
 151
 152    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.connection_pool.lock();
 156        ConnectionPoolGuard {
 157            guard,
 158            _not_send: PhantomData,
 159        }
 160    }
 161
 162    fn is_staff(&self) -> bool {
 163        match &self.principal {
 164            Principal::User(user) => user.admin,
 165            Principal::Impersonated { .. } => true,
 166        }
 167    }
 168
 169    pub async fn has_llm_subscription(
 170        &self,
 171        db: &MutexGuard<'_, DbHandle>,
 172    ) -> anyhow::Result<bool> {
 173        if self.is_staff() {
 174            return Ok(true);
 175        }
 176
 177        let user_id = self.user_id();
 178
 179        Ok(db.has_active_billing_subscription(user_id).await?)
 180    }
 181
 182    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 183        let user_id = self.user_id();
 184
 185        let subscription = db.get_active_billing_subscription(user_id).await?;
 186        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
 187
 188        let plan = if let Some(subscription_kind) = subscription_kind {
 189            match subscription_kind {
 190                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
 191                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
 192                SubscriptionKind::ZedFree => proto::Plan::Free,
 193            }
 194        } else {
 195            proto::Plan::Free
 196        };
 197
 198        Ok(plan)
 199    }
 200
 201    fn user_id(&self) -> UserId {
 202        match &self.principal {
 203            Principal::User(user) => user.id,
 204            Principal::Impersonated { user, .. } => user.id,
 205        }
 206    }
 207
 208    pub fn email(&self) -> Option<String> {
 209        match &self.principal {
 210            Principal::User(user) => user.email_address.clone(),
 211            Principal::Impersonated { user, .. } => user.email_address.clone(),
 212        }
 213    }
 214}
 215
 216impl Debug for Session {
 217    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 218        let mut result = f.debug_struct("Session");
 219        match &self.principal {
 220            Principal::User(user) => {
 221                result.field("user", &user.github_login);
 222            }
 223            Principal::Impersonated { user, admin } => {
 224                result.field("user", &user.github_login);
 225                result.field("impersonator", &admin.github_login);
 226            }
 227        }
 228        result.field("connection_id", &self.connection_id).finish()
 229    }
 230}
 231
 232struct DbHandle(Arc<Database>);
 233
 234impl Deref for DbHandle {
 235    type Target = Database;
 236
 237    fn deref(&self) -> &Self::Target {
 238        self.0.as_ref()
 239    }
 240}
 241
/// The RPC server. It owns the peer, the shared connection pool and the table of
/// registered message handlers.
pub struct Server {
    /// Identity of this server instance. It sits behind a mutex so the test-only
    /// `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sets it to `true` and `reset` sets it back to `false`.
    teardown: watch::Sender<bool>,
}
 250
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, which makes the whole guard `!Send`. Handler futures must be
    /// `Send`, so the compiler rejects holding this guard across an `.await`.
    _not_send: PhantomData<Rc<()>>,
}
 255
/// Serializable view of the server's peer and connection pool.
///
/// The pool stays locked for as long as the snapshot is alive, because the
/// snapshot holds the guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 262
 263pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 264where
 265    S: Serializer,
 266    T: Deref<Target = U>,
 267    U: Serialize,
 268{
 269    Serialize::serialize(value.deref(), serializer)
 270}
 271
 272impl Server {
 273    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 274        let mut server = Self {
 275            id: parking_lot::Mutex::new(id),
 276            peer: Peer::new(id.0 as u32),
 277            app_state: app_state.clone(),
 278            connection_pool: Default::default(),
 279            handlers: Default::default(),
 280            teardown: watch::channel(false).0,
 281        };
 282
 283        server
 284            .add_request_handler(ping)
 285            .add_request_handler(create_room)
 286            .add_request_handler(join_room)
 287            .add_request_handler(rejoin_room)
 288            .add_request_handler(leave_room)
 289            .add_request_handler(set_room_participant_role)
 290            .add_request_handler(call)
 291            .add_request_handler(cancel_call)
 292            .add_message_handler(decline_call)
 293            .add_request_handler(update_participant_location)
 294            .add_request_handler(share_project)
 295            .add_message_handler(unshare_project)
 296            .add_request_handler(join_project)
 297            .add_message_handler(leave_project)
 298            .add_request_handler(update_project)
 299            .add_request_handler(update_worktree)
 300            .add_request_handler(update_repository)
 301            .add_request_handler(remove_repository)
 302            .add_message_handler(start_language_server)
 303            .add_message_handler(update_language_server)
 304            .add_message_handler(update_diagnostic_summary)
 305            .add_message_handler(update_worktree_settings)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 309            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 310            .add_request_handler(forward_find_search_candidates_request)
 311            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 316            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 317            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 318            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 320            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 326            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 327            .add_request_handler(
 328                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 329            )
 330            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 331            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 332            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 334            .add_request_handler(
 335                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 336            )
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 341            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 342            .add_request_handler(
 343                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 344            )
 345            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 346            .add_request_handler(
 347                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 348            )
 349            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 351            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 352            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 353            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 354            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 355            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 356            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 357            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 358            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 359            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 361            .add_request_handler(
 362                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 363            )
 364            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 365            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 366            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 367            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 368            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 369            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 370            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 371            .add_message_handler(create_buffer_for_peer)
 372            .add_request_handler(update_buffer)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 376            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 377            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 378            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 379            .add_request_handler(get_users)
 380            .add_request_handler(fuzzy_search_users)
 381            .add_request_handler(request_contact)
 382            .add_request_handler(remove_contact)
 383            .add_request_handler(respond_to_contact_request)
 384            .add_message_handler(subscribe_to_channels)
 385            .add_request_handler(create_channel)
 386            .add_request_handler(delete_channel)
 387            .add_request_handler(invite_channel_member)
 388            .add_request_handler(remove_channel_member)
 389            .add_request_handler(set_channel_member_role)
 390            .add_request_handler(set_channel_visibility)
 391            .add_request_handler(rename_channel)
 392            .add_request_handler(join_channel_buffer)
 393            .add_request_handler(leave_channel_buffer)
 394            .add_message_handler(update_channel_buffer)
 395            .add_request_handler(rejoin_channel_buffers)
 396            .add_request_handler(get_channel_members)
 397            .add_request_handler(respond_to_channel_invite)
 398            .add_request_handler(join_channel)
 399            .add_request_handler(join_channel_chat)
 400            .add_message_handler(leave_channel_chat)
 401            .add_request_handler(send_channel_message)
 402            .add_request_handler(remove_channel_message)
 403            .add_request_handler(update_channel_message)
 404            .add_request_handler(get_channel_messages)
 405            .add_request_handler(get_channel_messages_by_id)
 406            .add_request_handler(get_notifications)
 407            .add_request_handler(mark_notification_as_read)
 408            .add_request_handler(move_channel)
 409            .add_request_handler(follow)
 410            .add_message_handler(unfollow)
 411            .add_message_handler(update_followers)
 412            .add_request_handler(get_private_user_info)
 413            .add_request_handler(get_llm_api_token)
 414            .add_request_handler(accept_terms_of_service)
 415            .add_message_handler(acknowledge_channel_message)
 416            .add_message_handler(acknowledge_buffer_version)
 417            .add_request_handler(get_supermaven_api_key)
 418            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 419            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 420            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 421            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 422            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 423            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 424            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 426            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 427            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 428            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 429            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 430            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 431            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 432            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 433            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 437            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 438            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 439            .add_message_handler(update_context);
 440
 441        Arc::new(server)
 442    }
 443
    /// Spawns a detached background task and returns `Ok(())` immediately.
    ///
    /// After `CLEANUP_TIMEOUT` elapses, the task asks the database for the room and
    /// channel-buffer ids it reports as stale for this environment and server
    /// (`stale_server_resource_ids`). For each one it:
    /// - clears stale channel-buffer collaborators and sends the refreshed
    ///   collaborator list to the remaining connections;
    /// - clears stale room participants, broadcasts the refreshed room (and its
    ///   channel, if any), sends `CallCanceled` to users whose calls were canceled,
    ///   pushes updated contact entries to affected users' accepted contacts, and
    ///   deletes the LiveKit room once no participants remain (only when a LiveKit
    ///   client is configured).
    ///
    /// Finally it deletes the stale server records.
    ///
    /// Every fallible step goes through `trace_err`, which turns errors into `None`
    /// (presumably logging them). A failing step is skipped and the cleanup continues.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                // Wait until pods terminated by the previous deployment are gone.
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Follow-up work collected from the DB update below; these stay
                        // empty/false if the update fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the parking_lot pool guard is released before the
                        // next `.await` below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (including
                        // busy status) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 589
 590    pub fn teardown(&self) {
 591        self.peer.teardown();
 592        self.connection_pool.lock().reset();
 593        let _ = self.teardown.send(true);
 594    }
 595
 596    #[cfg(test)]
 597    pub fn reset(&self, id: ServerId) {
 598        self.teardown();
 599        *self.id.lock() = id;
 600        self.peer.reset(id.0 as u32);
 601        let _ = self.teardown.send(false);
 602    }
 603
 604    #[cfg(test)]
 605    pub fn id(&self) -> ServerId {
 606        *self.id.lock()
 607    }
 608
 609    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 610    where
 611        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 612        Fut: 'static + Send + Future<Output = Result<()>>,
 613        M: EnvelopedMessage,
 614    {
 615        let prev_handler = self.handlers.insert(
 616            TypeId::of::<M>(),
 617            Box::new(move |envelope, session| {
 618                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 619                let received_at = envelope.received_at;
 620                tracing::info!("message received");
 621                let start_time = Instant::now();
 622                let future = (handler)(*envelope, session);
 623                async move {
 624                    let result = future.await;
 625                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 626                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 627                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 628                    let payload_type = M::NAME;
 629
 630                    match result {
 631                        Err(error) => {
 632                            tracing::error!(
 633                                ?error,
 634                                total_duration_ms,
 635                                processing_duration_ms,
 636                                queue_duration_ms,
 637                                payload_type,
 638                                "error handling message"
 639                            )
 640                        }
 641                        Ok(()) => tracing::info!(
 642                            total_duration_ms,
 643                            processing_duration_ms,
 644                            queue_duration_ms,
 645                            "finished handling message"
 646                        ),
 647                    }
 648                }
 649                .boxed()
 650            }),
 651        );
 652        if prev_handler.is_some() {
 653            panic!("registered a handler for the same message twice");
 654        }
 655        self
 656    }
 657
 658    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 659    where
 660        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 661        Fut: 'static + Send + Future<Output = Result<()>>,
 662        M: EnvelopedMessage,
 663    {
 664        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 665        self
 666    }
 667
 668    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 669    where
 670        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 671        Fut: Send + Future<Output = Result<()>>,
 672        M: RequestMessage,
 673    {
 674        let handler = Arc::new(handler);
 675        self.add_handler(move |envelope, session| {
 676            let receipt = envelope.receipt();
 677            let handler = handler.clone();
 678            async move {
 679                let peer = session.peer.clone();
 680                let responded = Arc::new(AtomicBool::default());
 681                let response = Response {
 682                    peer: peer.clone(),
 683                    responded: responded.clone(),
 684                    receipt,
 685                };
 686                match (handler)(envelope.payload, response, session).await {
 687                    Ok(()) => {
 688                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 689                            Ok(())
 690                        } else {
 691                            Err(anyhow!("handler did not send a response"))?
 692                        }
 693                    }
 694                    Err(error) => {
 695                        let proto_err = match &error {
 696                            Error::Internal(err) => err.to_proto(),
 697                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 698                        };
 699                        peer.respond_with_error(receipt, proto_err)?;
 700                        Err(error)
 701                    }
 702                }
 703            }
 704        })
 705    }
 706
    /// Builds the future that services one client connection for its whole lifetime.
    ///
    /// The returned future runs inside a `handle connection` span. It:
    /// 1. returns immediately if the server teardown flag is already set;
    /// 2. registers the connection with the peer and records its id on the span;
    /// 3. builds the per-connection `Session` and sends the initial client update
    ///    (`send_initial_client_update`, which also delivers the id on `send_connection_id`);
    /// 4. dispatches incoming messages to the registered handlers until the connection closes,
    ///    connection I/O ends, or the teardown signal changes;
    /// 5. on a close or I/O exit (but not on teardown), calls `connection_lost` to sign out.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Connection-scoped span. Fields declared `Empty` are filled in below via
        // `principal.update_span` / `record` once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Checked once before registering the connection, then watched in the message loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // Yields the connection id, the future that drives the connection's I/O, and the
            // stream of incoming messages. The closure gives the peer the executor's `sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): within this function the HTTP client is only used to build the
            // optional Supermaven admin client below, yet failing to create it aborts the
            // connection even when no Supermaven key is configured. Also, this early return and
            // the one after `send_initial_client_update` happen after `add_connection` and do not
            // call `connection_lost`, so its cleanup (`peer.disconnect`, connection-pool removal)
            // is not run on these paths — confirm this is intended.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only present when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            // Per-connection state handed (as clones) to every message handler.
            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) hold a permit at once.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The permit is acquired *before* the next message is read, so when all permits
                // are taken this connection stops reading messages until a handler finishes.
                // A fresh future is built each iteration; if another branch wins the select
                // below, this one (and any permit it already acquired) is dropped.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, connection I/O,
                // progress on in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown exits immediately, skipping `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span here; the handler future re-enters it on each
                                // poll via `.instrument(span)` below.
                                drop(span_enter);

                                // The permit is released only once the handler has completed.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // Incoming stream ended: the client went away.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers that are still running; background
            // handlers were spawned detached and are unaffected.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 854
 855    async fn send_initial_client_update(
 856        &self,
 857        connection_id: ConnectionId,
 858        principal: &Principal,
 859        zed_version: ZedVersion,
 860        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 861        session: &Session,
 862    ) -> Result<()> {
 863        self.peer.send(
 864            connection_id,
 865            proto::Hello {
 866                peer_id: Some(connection_id.into()),
 867            },
 868        )?;
 869        tracing::info!("sent hello message");
 870        if let Some(send_connection_id) = send_connection_id.take() {
 871            let _ = send_connection_id.send(connection_id);
 872        }
 873
 874        match principal {
 875            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 876                if !user.connected_once {
 877                    self.peer.send(connection_id, proto::ShowContacts {})?;
 878                    self.app_state
 879                        .db
 880                        .set_user_connected_once(user.id, true)
 881                        .await?;
 882                }
 883
 884                update_user_plan(user.id, session).await?;
 885
 886                let contacts = self.app_state.db.get_contacts(user.id).await?;
 887
 888                {
 889                    let mut pool = self.connection_pool.lock();
 890                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 891                    self.peer.send(
 892                        connection_id,
 893                        build_initial_contacts_update(contacts, &pool),
 894                    )?;
 895                }
 896
 897                if should_auto_subscribe_to_channels(zed_version) {
 898                    subscribe_user_to_channels(user.id, session).await?;
 899                }
 900
 901                if let Some(incoming_call) =
 902                    self.app_state.db.incoming_call_for_user(user.id).await?
 903                {
 904                    self.peer.send(connection_id, incoming_call)?;
 905                }
 906
 907                update_user_contacts(user.id, session).await?;
 908            }
 909        }
 910
 911        Ok(())
 912    }
 913
 914    pub async fn invite_code_redeemed(
 915        self: &Arc<Self>,
 916        inviter_id: UserId,
 917        invitee_id: UserId,
 918    ) -> Result<()> {
 919        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 920            if let Some(code) = &user.invite_code {
 921                let pool = self.connection_pool.lock();
 922                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 923                for connection_id in pool.user_connection_ids(inviter_id) {
 924                    self.peer.send(
 925                        connection_id,
 926                        proto::UpdateContacts {
 927                            contacts: vec![invitee_contact.clone()],
 928                            ..Default::default()
 929                        },
 930                    )?;
 931                    self.peer.send(
 932                        connection_id,
 933                        proto::UpdateInviteInfo {
 934                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 935                            count: user.invite_count as u32,
 936                        },
 937                    )?;
 938                }
 939            }
 940        }
 941        Ok(())
 942    }
 943
 944    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 945        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 946            if let Some(invite_code) = &user.invite_code {
 947                let pool = self.connection_pool.lock();
 948                for connection_id in pool.user_connection_ids(user_id) {
 949                    self.peer.send(
 950                        connection_id,
 951                        proto::UpdateInviteInfo {
 952                            url: format!(
 953                                "{}{}",
 954                                self.app_state.config.invite_link_prefix, invite_code
 955                            ),
 956                            count: user.invite_count as u32,
 957                        },
 958                    )?;
 959                }
 960            }
 961        }
 962        Ok(())
 963    }
 964
 965    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 966        let pool = self.connection_pool.lock();
 967        for connection_id in pool.user_connection_ids(user_id) {
 968            self.peer
 969                .send(connection_id, proto::RefreshLlmToken {})
 970                .trace_err();
 971        }
 972    }
 973
 974    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 975        ServerSnapshot {
 976            connection_pool: ConnectionPoolGuard {
 977                guard: self.connection_pool.lock(),
 978                _not_send: PhantomData,
 979            },
 980            peer: &self.peer,
 981        }
 982    }
 983}
 984
 985impl Deref for ConnectionPoolGuard<'_> {
 986    type Target = ConnectionPool;
 987
 988    fn deref(&self) -> &Self::Target {
 989        &self.guard
 990    }
 991}
 992
 993impl DerefMut for ConnectionPoolGuard<'_> {
 994    fn deref_mut(&mut self) -> &mut Self::Target {
 995        &mut self.guard
 996    }
 997}
 998
/// In test builds, validates the pool's invariants every time a guard is released.
impl Drop for ConnectionPoolGuard<'_> {
    /// `drop` runs before the guard's fields are dropped, so the check happens while the pool
    /// is still locked. In non-test builds the body compiles to nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1005
1006fn broadcast<F>(
1007    sender_id: Option<ConnectionId>,
1008    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1009    mut f: F,
1010) where
1011    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1012{
1013    for receiver_id in receiver_ids {
1014        if Some(receiver_id) != sender_id {
1015            if let Err(error) = f(receiver_id) {
1016                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1017            }
1018        }
1019    }
1020}
1021
1022pub struct ProtocolVersion(u32);
1023
1024impl Header for ProtocolVersion {
1025    fn name() -> &'static HeaderName {
1026        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1027        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1028    }
1029
1030    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1031    where
1032        Self: Sized,
1033        I: Iterator<Item = &'i axum::http::HeaderValue>,
1034    {
1035        let version = values
1036            .next()
1037            .ok_or_else(axum::headers::Error::invalid)?
1038            .to_str()
1039            .map_err(|_| axum::headers::Error::invalid())?
1040            .parse()
1041            .map_err(|_| axum::headers::Error::invalid())?;
1042        Ok(Self(version))
1043    }
1044
1045    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1046        values.extend([self.0.to_string().parse().unwrap()]);
1047    }
1048}
1049
1050pub struct AppVersionHeader(SemanticVersion);
1051impl Header for AppVersionHeader {
1052    fn name() -> &'static HeaderName {
1053        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1054        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1055    }
1056
1057    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1058    where
1059        Self: Sized,
1060        I: Iterator<Item = &'i axum::http::HeaderValue>,
1061    {
1062        let version = values
1063            .next()
1064            .ok_or_else(axum::headers::Error::invalid)?
1065            .to_str()
1066            .map_err(|_| axum::headers::Error::invalid())?
1067            .parse()
1068            .map_err(|_| axum::headers::Error::invalid())?;
1069        Ok(Self(version))
1070    }
1071
1072    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1073        values.extend([self.0.to_string().parse().unwrap()]);
1074    }
1075}
1076
1077pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1078    Router::new()
1079        .route("/rpc", get(handle_websocket_request))
1080        .layer(
1081            ServiceBuilder::new()
1082                .layer(Extension(server.app_state.clone()))
1083                .layer(middleware::from_fn(auth::validate_header)),
1084        )
1085        .route("/metrics", get(handle_metrics))
1086        .layer(Extension(server))
1087}
1088
1089pub async fn handle_websocket_request(
1090    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1091    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1092    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1093    Extension(server): Extension<Arc<Server>>,
1094    Extension(principal): Extension<Principal>,
1095    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1096    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1097    ws: WebSocketUpgrade,
1098) -> axum::response::Response {
1099    if protocol_version != rpc::PROTOCOL_VERSION {
1100        return (
1101            StatusCode::UPGRADE_REQUIRED,
1102            "client must be upgraded".to_string(),
1103        )
1104            .into_response();
1105    }
1106
1107    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "no version header found".to_string(),
1111        )
1112            .into_response();
1113    };
1114
1115    if !version.can_collaborate() {
1116        return (
1117            StatusCode::UPGRADE_REQUIRED,
1118            "client must be upgraded".to_string(),
1119        )
1120            .into_response();
1121    }
1122
1123    let socket_address = socket_address.to_string();
1124    ws.on_upgrade(move |socket| {
1125        let socket = socket
1126            .map_ok(to_tungstenite_message)
1127            .err_into()
1128            .with(|message| async move { to_axum_message(message) });
1129        let connection = Connection::new(Box::pin(socket));
1130        async move {
1131            server
1132                .handle_connection(
1133                    connection,
1134                    socket_address,
1135                    principal,
1136                    version,
1137                    country_code_header.map(|header| header.to_string()),
1138                    system_id_header.map(|header| header.to_string()),
1139                    None,
1140                    Executor::Production,
1141                )
1142                .await;
1143        }
1144    })
1145}
1146
1147pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1148    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1149    let connections_metric = CONNECTIONS_METRIC
1150        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1151
1152    let connections = server
1153        .connection_pool
1154        .lock()
1155        .connections()
1156        .filter(|connection| !connection.admin)
1157        .count();
1158    connections_metric.set(connections as _);
1159
1160    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1161    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1162        register_int_gauge!(
1163            "shared_projects",
1164            "number of open projects with one or more guests"
1165        )
1166        .unwrap()
1167    });
1168
1169    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1170    shared_projects_metric.set(shared_projects as _);
1171
1172    let encoder = prometheus::TextEncoder::new();
1173    let metric_families = prometheus::gather();
1174    let encoded_metrics = encoder
1175        .encode_to_string(&metric_families)
1176        .map_err(|err| anyhow!("{}", err))?;
1177    Ok(encoded_metrics)
1178}
1179
1180#[instrument(err, skip(executor))]
1181async fn connection_lost(
1182    session: Session,
1183    mut teardown: watch::Receiver<bool>,
1184    executor: Executor,
1185) -> Result<()> {
1186    session.peer.disconnect(session.connection_id);
1187    session
1188        .connection_pool()
1189        .await
1190        .remove_connection(session.connection_id)?;
1191
1192    session
1193        .db()
1194        .await
1195        .connection_lost(session.connection_id)
1196        .await
1197        .trace_err();
1198
1199    futures::select_biased! {
1200        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1201
1202            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1203            leave_room_for_session(&session, session.connection_id).await.trace_err();
1204            leave_channel_buffers_for_session(&session)
1205                .await
1206                .trace_err();
1207
1208            if !session
1209                .connection_pool()
1210                .await
1211                .is_user_online(session.user_id())
1212            {
1213                let db = session.db().await;
1214                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1215                    room_updated(&room, &session.peer);
1216                }
1217            }
1218
1219            update_user_contacts(session.user_id(), &session).await?;
1220        },
1221        _ = teardown.changed().fuse() => {}
1222    }
1223
1224    Ok(())
1225}
1226
1227/// Acknowledges a ping from a client, used to keep the connection alive.
1228async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1229    response.send(proto::Ack {})?;
1230    Ok(())
1231}
1232
1233/// Creates a new room for calling (outside of channels)
1234async fn create_room(
1235    _request: proto::CreateRoom,
1236    response: Response<proto::CreateRoom>,
1237    session: Session,
1238) -> Result<()> {
1239    let livekit_room = nanoid::nanoid!(30);
1240
1241    let live_kit_connection_info = util::maybe!(async {
1242        let live_kit = session.app_state.livekit_client.as_ref();
1243        let live_kit = live_kit?;
1244        let user_id = session.user_id().to_string();
1245
1246        let token = live_kit
1247            .room_token(&livekit_room, &user_id.to_string())
1248            .trace_err()?;
1249
1250        Some(proto::LiveKitConnectionInfo {
1251            server_url: live_kit.url().into(),
1252            token,
1253            can_publish: true,
1254        })
1255    })
1256    .await;
1257
1258    let room = session
1259        .db()
1260        .await
1261        .create_room(session.user_id(), session.connection_id, &livekit_room)
1262        .await?;
1263
1264    response.send(proto::CreateRoomResponse {
1265        room: Some(room.clone()),
1266        live_kit_connection_info,
1267    })?;
1268
1269    update_user_contacts(session.user_id(), &session).await?;
1270    Ok(())
1271}
1272
1273/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1274async fn join_room(
1275    request: proto::JoinRoom,
1276    response: Response<proto::JoinRoom>,
1277    session: Session,
1278) -> Result<()> {
1279    let room_id = RoomId::from_proto(request.id);
1280
1281    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1282
1283    if let Some(channel_id) = channel_id {
1284        return join_channel_internal(channel_id, Box::new(response), session).await;
1285    }
1286
1287    let joined_room = {
1288        let room = session
1289            .db()
1290            .await
1291            .join_room(room_id, session.user_id(), session.connection_id)
1292            .await?;
1293        room_updated(&room.room, &session.peer);
1294        room.into_inner()
1295    };
1296
1297    for connection_id in session
1298        .connection_pool()
1299        .await
1300        .user_connection_ids(session.user_id())
1301    {
1302        session
1303            .peer
1304            .send(
1305                connection_id,
1306                proto::CallCanceled {
1307                    room_id: room_id.to_proto(),
1308                },
1309            )
1310            .trace_err();
1311    }
1312
1313    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1314    {
1315        live_kit
1316            .room_token(
1317                &joined_room.room.livekit_room,
1318                &session.user_id().to_string(),
1319            )
1320            .trace_err()
1321            .map(|token| proto::LiveKitConnectionInfo {
1322                server_url: live_kit.url().into(),
1323                token,
1324                can_publish: true,
1325            })
1326    } else {
1327        None
1328    };
1329
1330    response.send(proto::JoinRoomResponse {
1331        room: Some(joined_room.room),
1332        channel_id: None,
1333        live_kit_connection_info,
1334    })?;
1335
1336    update_user_contacts(session.user_id(), &session).await?;
1337    Ok(())
1338}
1339
/// Reconnects the caller to a room after a connection error.
///
/// Steps, in order:
/// 1. Reply with the rejoined room and its reshared and rejoined projects.
/// 2. Broadcast the room as updated.
/// 3. Tell affected collaborators about the caller's new peer id.
/// 4. Forward worktree lists for reshared projects.
/// 5. Stream rejoined projects' state back to the caller (`notify_rejoined_projects`).
/// 6. Refresh the room's channel, if any, and the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned at the end of the block below, so that `rejoined_room` (which exposes
    // `into_inner()`; presumably a guard over database state — TODO confirm) goes out of scope
    // before the connection pool is awaited further down.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project: tell every collaborator that the caller's peer id changed
        // from `old_connection_id` to the current connection, then forward the project's
        // worktree list to everyone except the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes the worktrees/repositories out of the rejoined projects while streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1431
1432fn notify_rejoined_projects(
1433    rejoined_projects: &mut Vec<RejoinedProject>,
1434    session: &Session,
1435) -> Result<()> {
1436    for project in rejoined_projects.iter() {
1437        for collaborator in &project.collaborators {
1438            session
1439                .peer
1440                .send(
1441                    collaborator.connection_id,
1442                    proto::UpdateProjectCollaborator {
1443                        project_id: project.id.to_proto(),
1444                        old_peer_id: Some(project.old_connection_id.into()),
1445                        new_peer_id: Some(session.connection_id.into()),
1446                    },
1447                )
1448                .trace_err();
1449        }
1450    }
1451
1452    for project in rejoined_projects {
1453        for worktree in mem::take(&mut project.worktrees) {
1454            // Stream this worktree's entries.
1455            let message = proto::UpdateWorktree {
1456                project_id: project.id.to_proto(),
1457                worktree_id: worktree.id,
1458                abs_path: worktree.abs_path.clone(),
1459                root_name: worktree.root_name,
1460                updated_entries: worktree.updated_entries,
1461                removed_entries: worktree.removed_entries,
1462                scan_id: worktree.scan_id,
1463                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1464                updated_repositories: worktree.updated_repositories,
1465                removed_repositories: worktree.removed_repositories,
1466            };
1467            for update in proto::split_worktree_update(message) {
1468                session.peer.send(session.connection_id, update)?;
1469            }
1470
1471            // Stream this worktree's diagnostics.
1472            for summary in worktree.diagnostic_summaries {
1473                session.peer.send(
1474                    session.connection_id,
1475                    proto::UpdateDiagnosticSummary {
1476                        project_id: project.id.to_proto(),
1477                        worktree_id: worktree.id,
1478                        summary: Some(summary),
1479                    },
1480                )?;
1481            }
1482
1483            for settings_file in worktree.settings_files {
1484                session.peer.send(
1485                    session.connection_id,
1486                    proto::UpdateWorktreeSettings {
1487                        project_id: project.id.to_proto(),
1488                        worktree_id: worktree.id,
1489                        path: settings_file.path,
1490                        content: Some(settings_file.content),
1491                        kind: Some(settings_file.kind.to_proto().into()),
1492                    },
1493                )?;
1494            }
1495        }
1496
1497        for repository in mem::take(&mut project.updated_repositories) {
1498            for update in split_repository_update(repository) {
1499                session.peer.send(session.connection_id, update)?;
1500            }
1501        }
1502
1503        for id in mem::take(&mut project.removed_repositories) {
1504            session.peer.send(
1505                session.connection_id,
1506                proto::RemoveRepository {
1507                    project_id: project.id.to_proto(),
1508                    id,
1509                },
1510            )?;
1511        }
1512    }
1513
1514    Ok(())
1515}
1516
1517/// leave room disconnects from the room.
1518async fn leave_room(
1519    _: proto::LeaveRoom,
1520    response: Response<proto::LeaveRoom>,
1521    session: Session,
1522) -> Result<()> {
1523    leave_room_for_session(&session, session.connection_id).await?;
1524    response.send(proto::Ack {})?;
1525    Ok(())
1526}
1527
1528/// Updates the permissions of someone else in the room.
1529async fn set_room_participant_role(
1530    request: proto::SetRoomParticipantRole,
1531    response: Response<proto::SetRoomParticipantRole>,
1532    session: Session,
1533) -> Result<()> {
1534    let user_id = UserId::from_proto(request.user_id);
1535    let role = ChannelRole::from(request.role());
1536
1537    let (livekit_room, can_publish) = {
1538        let room = session
1539            .db()
1540            .await
1541            .set_room_participant_role(
1542                session.user_id(),
1543                RoomId::from_proto(request.room_id),
1544                user_id,
1545                role,
1546            )
1547            .await?;
1548
1549        let livekit_room = room.livekit_room.clone();
1550        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1551        room_updated(&room, &session.peer);
1552        (livekit_room, can_publish)
1553    };
1554
1555    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1556        live_kit
1557            .update_participant(
1558                livekit_room.clone(),
1559                request.user_id.to_string(),
1560                livekit_api::proto::ParticipantPermission {
1561                    can_subscribe: true,
1562                    can_publish,
1563                    can_publish_data: can_publish,
1564                    hidden: false,
1565                    recorder: false,
1566                },
1567            )
1568            .await
1569            .trace_err();
1570    }
1571
1572    response.send(proto::Ack {})?;
1573    Ok(())
1574}
1575
1576/// Call someone else into the current room
1577async fn call(
1578    request: proto::Call,
1579    response: Response<proto::Call>,
1580    session: Session,
1581) -> Result<()> {
1582    let room_id = RoomId::from_proto(request.room_id);
1583    let calling_user_id = session.user_id();
1584    let calling_connection_id = session.connection_id;
1585    let called_user_id = UserId::from_proto(request.called_user_id);
1586    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1587    if !session
1588        .db()
1589        .await
1590        .has_contact(calling_user_id, called_user_id)
1591        .await?
1592    {
1593        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1594    }
1595
1596    let incoming_call = {
1597        let (room, incoming_call) = &mut *session
1598            .db()
1599            .await
1600            .call(
1601                room_id,
1602                calling_user_id,
1603                calling_connection_id,
1604                called_user_id,
1605                initial_project_id,
1606            )
1607            .await?;
1608        room_updated(room, &session.peer);
1609        mem::take(incoming_call)
1610    };
1611    update_user_contacts(called_user_id, &session).await?;
1612
1613    let mut calls = session
1614        .connection_pool()
1615        .await
1616        .user_connection_ids(called_user_id)
1617        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1618        .collect::<FuturesUnordered<_>>();
1619
1620    while let Some(call_response) = calls.next().await {
1621        match call_response.as_ref() {
1622            Ok(_) => {
1623                response.send(proto::Ack {})?;
1624                return Ok(());
1625            }
1626            Err(_) => {
1627                call_response.trace_err();
1628            }
1629        }
1630    }
1631
1632    {
1633        let room = session
1634            .db()
1635            .await
1636            .call_failed(room_id, called_user_id)
1637            .await?;
1638        room_updated(&room, &session.peer);
1639    }
1640    update_user_contacts(called_user_id, &session).await?;
1641
1642    Err(anyhow!("failed to ring user"))?
1643}
1644
1645/// Cancel an outgoing call.
1646async fn cancel_call(
1647    request: proto::CancelCall,
1648    response: Response<proto::CancelCall>,
1649    session: Session,
1650) -> Result<()> {
1651    let called_user_id = UserId::from_proto(request.called_user_id);
1652    let room_id = RoomId::from_proto(request.room_id);
1653    {
1654        let room = session
1655            .db()
1656            .await
1657            .cancel_call(room_id, session.connection_id, called_user_id)
1658            .await?;
1659        room_updated(&room, &session.peer);
1660    }
1661
1662    for connection_id in session
1663        .connection_pool()
1664        .await
1665        .user_connection_ids(called_user_id)
1666    {
1667        session
1668            .peer
1669            .send(
1670                connection_id,
1671                proto::CallCanceled {
1672                    room_id: room_id.to_proto(),
1673                },
1674            )
1675            .trace_err();
1676    }
1677    response.send(proto::Ack {})?;
1678
1679    update_user_contacts(called_user_id, &session).await?;
1680    Ok(())
1681}
1682
1683/// Decline an incoming call.
1684async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1685    let room_id = RoomId::from_proto(message.room_id);
1686    {
1687        let room = session
1688            .db()
1689            .await
1690            .decline_call(Some(room_id), session.user_id())
1691            .await?
1692            .ok_or_else(|| anyhow!("failed to decline call"))?;
1693        room_updated(&room, &session.peer);
1694    }
1695
1696    for connection_id in session
1697        .connection_pool()
1698        .await
1699        .user_connection_ids(session.user_id())
1700    {
1701        session
1702            .peer
1703            .send(
1704                connection_id,
1705                proto::CallCanceled {
1706                    room_id: room_id.to_proto(),
1707                },
1708            )
1709            .trace_err();
1710    }
1711    update_user_contacts(session.user_id(), &session).await?;
1712    Ok(())
1713}
1714
1715/// Updates other participants in the room with your current location.
1716async fn update_participant_location(
1717    request: proto::UpdateParticipantLocation,
1718    response: Response<proto::UpdateParticipantLocation>,
1719    session: Session,
1720) -> Result<()> {
1721    let room_id = RoomId::from_proto(request.room_id);
1722    let location = request
1723        .location
1724        .ok_or_else(|| anyhow!("invalid location"))?;
1725
1726    let db = session.db().await;
1727    let room = db
1728        .update_room_participant_location(room_id, session.connection_id, location)
1729        .await?;
1730
1731    room_updated(&room, &session.peer);
1732    response.send(proto::Ack {})?;
1733    Ok(())
1734}
1735
1736/// Share a project into the room.
1737async fn share_project(
1738    request: proto::ShareProject,
1739    response: Response<proto::ShareProject>,
1740    session: Session,
1741) -> Result<()> {
1742    let (project_id, room) = &*session
1743        .db()
1744        .await
1745        .share_project(
1746            RoomId::from_proto(request.room_id),
1747            session.connection_id,
1748            &request.worktrees,
1749            request.is_ssh_project,
1750        )
1751        .await?;
1752    response.send(proto::ShareProjectResponse {
1753        project_id: project_id.to_proto(),
1754    })?;
1755    room_updated(room, &session.peer);
1756
1757    Ok(())
1758}
1759
1760/// Unshare a project from the room.
1761async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1762    let project_id = ProjectId::from_proto(message.project_id);
1763    unshare_project_internal(project_id, session.connection_id, &session).await
1764}
1765
1766async fn unshare_project_internal(
1767    project_id: ProjectId,
1768    connection_id: ConnectionId,
1769    session: &Session,
1770) -> Result<()> {
1771    let delete = {
1772        let room_guard = session
1773            .db()
1774            .await
1775            .unshare_project(project_id, connection_id)
1776            .await?;
1777
1778        let (delete, room, guest_connection_ids) = &*room_guard;
1779
1780        let message = proto::UnshareProject {
1781            project_id: project_id.to_proto(),
1782        };
1783
1784        broadcast(
1785            Some(connection_id),
1786            guest_connection_ids.iter().copied(),
1787            |conn_id| session.peer.send(conn_id, message.clone()),
1788        );
1789        if let Some(room) = room {
1790            room_updated(room, &session.peer);
1791        }
1792
1793        *delete
1794    };
1795
1796    if delete {
1797        let db = session.db().await;
1798        db.delete_project(project_id).await?;
1799    }
1800
1801    Ok(())
1802}
1803
/// Join someone else's shared project.
///
/// Joins `request.project_id` through the database for the current connection
/// and user. The database returns the project state and the replica id assigned
/// to the joining connection. The database lock is then released and
/// `join_project_internal` sends the join response and the initial project
/// state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database lock before streaming project state. The joined
    // project data is owned by the guard returned from `join_project`, not by
    // `db`, which is why `project`/`replica_id` stay usable after this drop.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1822
/// Something that can deliver the `JoinProjectResponse` produced by
/// `join_project_internal`, which lets that function accept any such responder
/// (via `impl JoinProjectInternalResponse`) rather than a concrete type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join; consumes the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the join response as the reply to a `proto::JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicitly call the inherent `Response::send`; the fully qualified
        // path makes clear this is not a recursive call to the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1831
1832fn join_project_internal(
1833    response: impl JoinProjectInternalResponse,
1834    session: Session,
1835    project: &mut Project,
1836    replica_id: &ReplicaId,
1837) -> Result<()> {
1838    let collaborators = project
1839        .collaborators
1840        .iter()
1841        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842        .map(|collaborator| collaborator.to_proto())
1843        .collect::<Vec<_>>();
1844    let project_id = project.id;
1845    let guest_user_id = session.user_id();
1846
1847    let worktrees = project
1848        .worktrees
1849        .iter()
1850        .map(|(id, worktree)| proto::WorktreeMetadata {
1851            id: *id,
1852            root_name: worktree.root_name.clone(),
1853            visible: worktree.visible,
1854            abs_path: worktree.abs_path.clone(),
1855        })
1856        .collect::<Vec<_>>();
1857
1858    let add_project_collaborator = proto::AddProjectCollaborator {
1859        project_id: project_id.to_proto(),
1860        collaborator: Some(proto::Collaborator {
1861            peer_id: Some(session.connection_id.into()),
1862            replica_id: replica_id.0 as u32,
1863            user_id: guest_user_id.to_proto(),
1864            is_host: false,
1865        }),
1866    };
1867
1868    for collaborator in &collaborators {
1869        session
1870            .peer
1871            .send(
1872                collaborator.peer_id.unwrap().into(),
1873                add_project_collaborator.clone(),
1874            )
1875            .trace_err();
1876    }
1877
1878    // First, we send the metadata associated with each worktree.
1879    response.send(proto::JoinProjectResponse {
1880        project_id: project.id.0 as u64,
1881        worktrees: worktrees.clone(),
1882        replica_id: replica_id.0 as u32,
1883        collaborators: collaborators.clone(),
1884        language_servers: project.language_servers.clone(),
1885        role: project.role.into(),
1886    })?;
1887
1888    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1889        // Stream this worktree's entries.
1890        let message = proto::UpdateWorktree {
1891            project_id: project_id.to_proto(),
1892            worktree_id,
1893            abs_path: worktree.abs_path.clone(),
1894            root_name: worktree.root_name,
1895            updated_entries: worktree.entries,
1896            removed_entries: Default::default(),
1897            scan_id: worktree.scan_id,
1898            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1899            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1900            removed_repositories: Default::default(),
1901        };
1902        for update in proto::split_worktree_update(message) {
1903            session.peer.send(session.connection_id, update.clone())?;
1904        }
1905
1906        // Stream this worktree's diagnostics.
1907        for summary in worktree.diagnostic_summaries {
1908            session.peer.send(
1909                session.connection_id,
1910                proto::UpdateDiagnosticSummary {
1911                    project_id: project_id.to_proto(),
1912                    worktree_id: worktree.id,
1913                    summary: Some(summary),
1914                },
1915            )?;
1916        }
1917
1918        for settings_file in worktree.settings_files {
1919            session.peer.send(
1920                session.connection_id,
1921                proto::UpdateWorktreeSettings {
1922                    project_id: project_id.to_proto(),
1923                    worktree_id: worktree.id,
1924                    path: settings_file.path,
1925                    content: Some(settings_file.content),
1926                    kind: Some(settings_file.kind.to_proto() as i32),
1927                },
1928            )?;
1929        }
1930    }
1931
1932    for repository in mem::take(&mut project.repositories) {
1933        for update in split_repository_update(repository) {
1934            session.peer.send(session.connection_id, update)?;
1935        }
1936    }
1937
1938    for language_server in &project.language_servers {
1939        session.peer.send(
1940            session.connection_id,
1941            proto::UpdateLanguageServer {
1942                project_id: project_id.to_proto(),
1943                language_server_id: language_server.id,
1944                variant: Some(
1945                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1946                        proto::LspDiskBasedDiagnosticsUpdated {},
1947                    ),
1948                ),
1949            },
1950        )?;
1951    }
1952
1953    Ok(())
1954}
1955
1956/// Leave someone elses shared project.
1957async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1958    let sender_id = session.connection_id;
1959    let project_id = ProjectId::from_proto(request.project_id);
1960    let db = session.db().await;
1961
1962    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963    tracing::info!(
1964        %project_id,
1965        "leave project"
1966    );
1967
1968    project_left(project, &session);
1969    if let Some(room) = room {
1970        room_updated(room, &session.peer);
1971    }
1972
1973    Ok(())
1974}
1975
1976/// Updates other participants with changes to the project
1977async fn update_project(
1978    request: proto::UpdateProject,
1979    response: Response<proto::UpdateProject>,
1980    session: Session,
1981) -> Result<()> {
1982    let project_id = ProjectId::from_proto(request.project_id);
1983    let (room, guest_connection_ids) = &*session
1984        .db()
1985        .await
1986        .update_project(project_id, session.connection_id, &request.worktrees)
1987        .await?;
1988    broadcast(
1989        Some(session.connection_id),
1990        guest_connection_ids.iter().copied(),
1991        |connection_id| {
1992            session
1993                .peer
1994                .forward_send(session.connection_id, connection_id, request.clone())
1995        },
1996    );
1997    if let Some(room) = room {
1998        room_updated(room, &session.peer);
1999    }
2000    response.send(proto::Ack {})?;
2001
2002    Ok(())
2003}
2004
2005/// Updates other participants with changes to the worktree
2006async fn update_worktree(
2007    request: proto::UpdateWorktree,
2008    response: Response<proto::UpdateWorktree>,
2009    session: Session,
2010) -> Result<()> {
2011    let guest_connection_ids = session
2012        .db()
2013        .await
2014        .update_worktree(&request, session.connection_id)
2015        .await?;
2016
2017    broadcast(
2018        Some(session.connection_id),
2019        guest_connection_ids.iter().copied(),
2020        |connection_id| {
2021            session
2022                .peer
2023                .forward_send(session.connection_id, connection_id, request.clone())
2024        },
2025    );
2026    response.send(proto::Ack {})?;
2027    Ok(())
2028}
2029
2030async fn update_repository(
2031    request: proto::UpdateRepository,
2032    response: Response<proto::UpdateRepository>,
2033    session: Session,
2034) -> Result<()> {
2035    let guest_connection_ids = session
2036        .db()
2037        .await
2038        .update_repository(&request, session.connection_id)
2039        .await?;
2040
2041    broadcast(
2042        Some(session.connection_id),
2043        guest_connection_ids.iter().copied(),
2044        |connection_id| {
2045            session
2046                .peer
2047                .forward_send(session.connection_id, connection_id, request.clone())
2048        },
2049    );
2050    response.send(proto::Ack {})?;
2051    Ok(())
2052}
2053
2054async fn remove_repository(
2055    request: proto::RemoveRepository,
2056    response: Response<proto::RemoveRepository>,
2057    session: Session,
2058) -> Result<()> {
2059    let guest_connection_ids = session
2060        .db()
2061        .await
2062        .remove_repository(&request, session.connection_id)
2063        .await?;
2064
2065    broadcast(
2066        Some(session.connection_id),
2067        guest_connection_ids.iter().copied(),
2068        |connection_id| {
2069            session
2070                .peer
2071                .forward_send(session.connection_id, connection_id, request.clone())
2072        },
2073    );
2074    response.send(proto::Ack {})?;
2075    Ok(())
2076}
2077
2078/// Updates other participants with changes to the diagnostics
2079async fn update_diagnostic_summary(
2080    message: proto::UpdateDiagnosticSummary,
2081    session: Session,
2082) -> Result<()> {
2083    let guest_connection_ids = session
2084        .db()
2085        .await
2086        .update_diagnostic_summary(&message, session.connection_id)
2087        .await?;
2088
2089    broadcast(
2090        Some(session.connection_id),
2091        guest_connection_ids.iter().copied(),
2092        |connection_id| {
2093            session
2094                .peer
2095                .forward_send(session.connection_id, connection_id, message.clone())
2096        },
2097    );
2098
2099    Ok(())
2100}
2101
2102/// Updates other participants with changes to the worktree settings
2103async fn update_worktree_settings(
2104    message: proto::UpdateWorktreeSettings,
2105    session: Session,
2106) -> Result<()> {
2107    let guest_connection_ids = session
2108        .db()
2109        .await
2110        .update_worktree_settings(&message, session.connection_id)
2111        .await?;
2112
2113    broadcast(
2114        Some(session.connection_id),
2115        guest_connection_ids.iter().copied(),
2116        |connection_id| {
2117            session
2118                .peer
2119                .forward_send(session.connection_id, connection_id, message.clone())
2120        },
2121    );
2122
2123    Ok(())
2124}
2125
2126/// Notify other participants that a language server has started.
2127async fn start_language_server(
2128    request: proto::StartLanguageServer,
2129    session: Session,
2130) -> Result<()> {
2131    let guest_connection_ids = session
2132        .db()
2133        .await
2134        .start_language_server(&request, session.connection_id)
2135        .await?;
2136
2137    broadcast(
2138        Some(session.connection_id),
2139        guest_connection_ids.iter().copied(),
2140        |connection_id| {
2141            session
2142                .peer
2143                .forward_send(session.connection_id, connection_id, request.clone())
2144        },
2145    );
2146    Ok(())
2147}
2148
2149/// Notify other participants that a language server has changed.
2150async fn update_language_server(
2151    request: proto::UpdateLanguageServer,
2152    session: Session,
2153) -> Result<()> {
2154    let project_id = ProjectId::from_proto(request.project_id);
2155    let project_connection_ids = session
2156        .db()
2157        .await
2158        .project_connection_ids(project_id, session.connection_id, true)
2159        .await?;
2160    broadcast(
2161        Some(session.connection_id),
2162        project_connection_ids.iter().copied(),
2163        |connection_id| {
2164            session
2165                .peer
2166                .forward_send(session.connection_id, connection_id, request.clone())
2167        },
2168    );
2169    Ok(())
2170}
2171
2172/// forward a project request to the host. These requests should be read only
2173/// as guests are allowed to send them.
2174async fn forward_read_only_project_request<T>(
2175    request: T,
2176    response: Response<T>,
2177    session: Session,
2178) -> Result<()>
2179where
2180    T: EntityMessage + RequestMessage,
2181{
2182    let project_id = ProjectId::from_proto(request.remote_entity_id());
2183    let host_connection_id = session
2184        .db()
2185        .await
2186        .host_for_read_only_project_request(project_id, session.connection_id)
2187        .await?;
2188    let payload = session
2189        .peer
2190        .forward_request(session.connection_id, host_connection_id, request)
2191        .await?;
2192    response.send(payload)?;
2193    Ok(())
2194}
2195
2196async fn forward_find_search_candidates_request(
2197    request: proto::FindSearchCandidates,
2198    response: Response<proto::FindSearchCandidates>,
2199    session: Session,
2200) -> Result<()> {
2201    let project_id = ProjectId::from_proto(request.remote_entity_id());
2202    let host_connection_id = session
2203        .db()
2204        .await
2205        .host_for_read_only_project_request(project_id, session.connection_id)
2206        .await?;
2207    let payload = session
2208        .peer
2209        .forward_request(session.connection_id, host_connection_id, request)
2210        .await?;
2211    response.send(payload)?;
2212    Ok(())
2213}
2214
2215/// forward a project request to the host. These requests are disallowed
2216/// for guests.
2217async fn forward_mutating_project_request<T>(
2218    request: T,
2219    response: Response<T>,
2220    session: Session,
2221) -> Result<()>
2222where
2223    T: EntityMessage + RequestMessage,
2224{
2225    let project_id = ProjectId::from_proto(request.remote_entity_id());
2226
2227    let host_connection_id = session
2228        .db()
2229        .await
2230        .host_for_mutating_project_request(project_id, session.connection_id)
2231        .await?;
2232    let payload = session
2233        .peer
2234        .forward_request(session.connection_id, host_connection_id, request)
2235        .await?;
2236    response.send(payload)?;
2237    Ok(())
2238}
2239
2240/// Notify other participants that a new buffer has been created
2241async fn create_buffer_for_peer(
2242    request: proto::CreateBufferForPeer,
2243    session: Session,
2244) -> Result<()> {
2245    session
2246        .db()
2247        .await
2248        .check_user_is_project_host(
2249            ProjectId::from_proto(request.project_id),
2250            session.connection_id,
2251        )
2252        .await?;
2253    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2254    session
2255        .peer
2256        .forward_send(session.connection_id, peer_id.into(), request)?;
2257    Ok(())
2258}
2259
2260/// Notify other participants that a buffer has been updated. This is
2261/// allowed for guests as long as the update is limited to selections.
2262async fn update_buffer(
2263    request: proto::UpdateBuffer,
2264    response: Response<proto::UpdateBuffer>,
2265    session: Session,
2266) -> Result<()> {
2267    let project_id = ProjectId::from_proto(request.project_id);
2268    let mut capability = Capability::ReadOnly;
2269
2270    for op in request.operations.iter() {
2271        match op.variant {
2272            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2273            Some(_) => capability = Capability::ReadWrite,
2274        }
2275    }
2276
2277    let host = {
2278        let guard = session
2279            .db()
2280            .await
2281            .connections_for_buffer_update(project_id, session.connection_id, capability)
2282            .await?;
2283
2284        let (host, guests) = &*guard;
2285
2286        broadcast(
2287            Some(session.connection_id),
2288            guests.clone(),
2289            |connection_id| {
2290                session
2291                    .peer
2292                    .forward_send(session.connection_id, connection_id, request.clone())
2293            },
2294        );
2295
2296        *host
2297    };
2298
2299    if host != session.connection_id {
2300        session
2301            .peer
2302            .forward_request(session.connection_id, host, request.clone())
2303            .await?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2311    let project_id = ProjectId::from_proto(message.project_id);
2312
2313    let operation = message.operation.as_ref().context("invalid operation")?;
2314    let capability = match operation.variant.as_ref() {
2315        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2316            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2317                match buffer_op.variant {
2318                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2319                        Capability::ReadOnly
2320                    }
2321                    _ => Capability::ReadWrite,
2322                }
2323            } else {
2324                Capability::ReadWrite
2325            }
2326        }
2327        Some(_) => Capability::ReadWrite,
2328        None => Capability::ReadOnly,
2329    };
2330
2331    let guard = session
2332        .db()
2333        .await
2334        .connections_for_buffer_update(project_id, session.connection_id, capability)
2335        .await?;
2336
2337    let (host, guests) = &*guard;
2338
2339    broadcast(
2340        Some(session.connection_id),
2341        guests.iter().chain([host]).copied(),
2342        |connection_id| {
2343            session
2344                .peer
2345                .forward_send(session.connection_id, connection_id, message.clone())
2346        },
2347    );
2348
2349    Ok(())
2350}
2351
2352/// Notify other participants that a project has been updated.
2353async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2354    request: T,
2355    session: Session,
2356) -> Result<()> {
2357    let project_id = ProjectId::from_proto(request.remote_entity_id());
2358    let project_connection_ids = session
2359        .db()
2360        .await
2361        .project_connection_ids(project_id, session.connection_id, false)
2362        .await?;
2363
2364    broadcast(
2365        Some(session.connection_id),
2366        project_connection_ids.iter().copied(),
2367        |connection_id| {
2368            session
2369                .peer
2370                .forward_send(session.connection_id, connection_id, request.clone())
2371        },
2372    );
2373    Ok(())
2374}
2375
2376/// Start following another user in a call.
2377async fn follow(
2378    request: proto::Follow,
2379    response: Response<proto::Follow>,
2380    session: Session,
2381) -> Result<()> {
2382    let room_id = RoomId::from_proto(request.room_id);
2383    let project_id = request.project_id.map(ProjectId::from_proto);
2384    let leader_id = request
2385        .leader_id
2386        .ok_or_else(|| anyhow!("invalid leader id"))?
2387        .into();
2388    let follower_id = session.connection_id;
2389
2390    session
2391        .db()
2392        .await
2393        .check_room_participants(room_id, leader_id, session.connection_id)
2394        .await?;
2395
2396    let response_payload = session
2397        .peer
2398        .forward_request(session.connection_id, leader_id, request)
2399        .await?;
2400    response.send(response_payload)?;
2401
2402    if let Some(project_id) = project_id {
2403        let room = session
2404            .db()
2405            .await
2406            .follow(room_id, project_id, leader_id, follower_id)
2407            .await?;
2408        room_updated(&room, &session.peer);
2409    }
2410
2411    Ok(())
2412}
2413
2414/// Stop following another user in a call.
2415async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2416    let room_id = RoomId::from_proto(request.room_id);
2417    let project_id = request.project_id.map(ProjectId::from_proto);
2418    let leader_id = request
2419        .leader_id
2420        .ok_or_else(|| anyhow!("invalid leader id"))?
2421        .into();
2422    let follower_id = session.connection_id;
2423
2424    session
2425        .db()
2426        .await
2427        .check_room_participants(room_id, leader_id, session.connection_id)
2428        .await?;
2429
2430    session
2431        .peer
2432        .forward_send(session.connection_id, leader_id, request)?;
2433
2434    if let Some(project_id) = project_id {
2435        let room = session
2436            .db()
2437            .await
2438            .unfollow(room_id, project_id, leader_id, follower_id)
2439            .await?;
2440        room_updated(&room, &session.peer);
2441    }
2442
2443    Ok(())
2444}
2445
2446/// Notify everyone following you of your current location.
2447async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2448    let room_id = RoomId::from_proto(request.room_id);
2449    let database = session.db.lock().await;
2450
2451    let connection_ids = if let Some(project_id) = request.project_id {
2452        let project_id = ProjectId::from_proto(project_id);
2453        database
2454            .project_connection_ids(project_id, session.connection_id, true)
2455            .await?
2456    } else {
2457        database
2458            .room_connection_ids(room_id, session.connection_id)
2459            .await?
2460    };
2461
2462    // For now, don't send view update messages back to that view's current leader.
2463    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2464        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2465        _ => None,
2466    });
2467
2468    for connection_id in connection_ids.iter().cloned() {
2469        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2470            session
2471                .peer
2472                .forward_send(session.connection_id, connection_id, request.clone())?;
2473        }
2474    }
2475    Ok(())
2476}
2477
2478/// Get public data about users.
2479async fn get_users(
2480    request: proto::GetUsers,
2481    response: Response<proto::GetUsers>,
2482    session: Session,
2483) -> Result<()> {
2484    let user_ids = request
2485        .user_ids
2486        .into_iter()
2487        .map(UserId::from_proto)
2488        .collect();
2489    let users = session
2490        .db()
2491        .await
2492        .get_users_by_ids(user_ids)
2493        .await?
2494        .into_iter()
2495        .map(|user| proto::User {
2496            id: user.id.to_proto(),
2497            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2498            github_login: user.github_login,
2499            email: user.email_address,
2500            name: user.name,
2501        })
2502        .collect();
2503    response.send(proto::UsersResponse { users })?;
2504    Ok(())
2505}
2506
2507/// Search for users (to invite) buy Github login
2508async fn fuzzy_search_users(
2509    request: proto::FuzzySearchUsers,
2510    response: Response<proto::FuzzySearchUsers>,
2511    session: Session,
2512) -> Result<()> {
2513    let query = request.query;
2514    let users = match query.len() {
2515        0 => vec![],
2516        1 | 2 => session
2517            .db()
2518            .await
2519            .get_user_by_github_login(&query)
2520            .await?
2521            .into_iter()
2522            .collect(),
2523        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2524    };
2525    let users = users
2526        .into_iter()
2527        .filter(|user| user.id != session.user_id())
2528        .map(|user| proto::User {
2529            id: user.id.to_proto(),
2530            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2531            github_login: user.github_login,
2532            name: user.name,
2533            email: user.email_address,
2534        })
2535        .collect();
2536    response.send(proto::UsersResponse { users })?;
2537    Ok(())
2538}
2539
2540/// Send a contact request to another user.
2541async fn request_contact(
2542    request: proto::RequestContact,
2543    response: Response<proto::RequestContact>,
2544    session: Session,
2545) -> Result<()> {
2546    let requester_id = session.user_id();
2547    let responder_id = UserId::from_proto(request.responder_id);
2548    if requester_id == responder_id {
2549        return Err(anyhow!("cannot add yourself as a contact"))?;
2550    }
2551
2552    let notifications = session
2553        .db()
2554        .await
2555        .send_contact_request(requester_id, responder_id)
2556        .await?;
2557
2558    // Update outgoing contact requests of requester
2559    let mut update = proto::UpdateContacts::default();
2560    update.outgoing_requests.push(responder_id.to_proto());
2561    for connection_id in session
2562        .connection_pool()
2563        .await
2564        .user_connection_ids(requester_id)
2565    {
2566        session.peer.send(connection_id, update.clone())?;
2567    }
2568
2569    // Update incoming contact requests of responder
2570    let mut update = proto::UpdateContacts::default();
2571    update
2572        .incoming_requests
2573        .push(proto::IncomingContactRequest {
2574            requester_id: requester_id.to_proto(),
2575        });
2576    let connection_pool = session.connection_pool().await;
2577    for connection_id in connection_pool.user_connection_ids(responder_id) {
2578        session.peer.send(connection_id, update.clone())?;
2579    }
2580
2581    send_notifications(&connection_pool, &session.peer, notifications);
2582
2583    response.send(proto::Ack {})?;
2584    Ok(())
2585}
2586
2587/// Accept or decline a contact request
2588async fn respond_to_contact_request(
2589    request: proto::RespondToContactRequest,
2590    response: Response<proto::RespondToContactRequest>,
2591    session: Session,
2592) -> Result<()> {
2593    let responder_id = session.user_id();
2594    let requester_id = UserId::from_proto(request.requester_id);
2595    let db = session.db().await;
2596    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2597        db.dismiss_contact_notification(responder_id, requester_id)
2598            .await?;
2599    } else {
2600        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2601
2602        let notifications = db
2603            .respond_to_contact_request(responder_id, requester_id, accept)
2604            .await?;
2605        let requester_busy = db.is_user_busy(requester_id).await?;
2606        let responder_busy = db.is_user_busy(responder_id).await?;
2607
2608        let pool = session.connection_pool().await;
2609        // Update responder with new contact
2610        let mut update = proto::UpdateContacts::default();
2611        if accept {
2612            update
2613                .contacts
2614                .push(contact_for_user(requester_id, requester_busy, &pool));
2615        }
2616        update
2617            .remove_incoming_requests
2618            .push(requester_id.to_proto());
2619        for connection_id in pool.user_connection_ids(responder_id) {
2620            session.peer.send(connection_id, update.clone())?;
2621        }
2622
2623        // Update requester with new contact
2624        let mut update = proto::UpdateContacts::default();
2625        if accept {
2626            update
2627                .contacts
2628                .push(contact_for_user(responder_id, responder_busy, &pool));
2629        }
2630        update
2631            .remove_outgoing_requests
2632            .push(responder_id.to_proto());
2633
2634        for connection_id in pool.user_connection_ids(requester_id) {
2635            session.peer.send(connection_id, update.clone())?;
2636        }
2637
2638        send_notifications(&pool, &session.peer, notifications);
2639    }
2640
2641    response.send(proto::Ack {})?;
2642    Ok(())
2643}
2644
2645/// Remove a contact.
2646async fn remove_contact(
2647    request: proto::RemoveContact,
2648    response: Response<proto::RemoveContact>,
2649    session: Session,
2650) -> Result<()> {
2651    let requester_id = session.user_id();
2652    let responder_id = UserId::from_proto(request.user_id);
2653    let db = session.db().await;
2654    let (contact_accepted, deleted_notification_id) =
2655        db.remove_contact(requester_id, responder_id).await?;
2656
2657    let pool = session.connection_pool().await;
2658    // Update outgoing contact requests of requester
2659    let mut update = proto::UpdateContacts::default();
2660    if contact_accepted {
2661        update.remove_contacts.push(responder_id.to_proto());
2662    } else {
2663        update
2664            .remove_outgoing_requests
2665            .push(responder_id.to_proto());
2666    }
2667    for connection_id in pool.user_connection_ids(requester_id) {
2668        session.peer.send(connection_id, update.clone())?;
2669    }
2670
2671    // Update incoming contact requests of responder
2672    let mut update = proto::UpdateContacts::default();
2673    if contact_accepted {
2674        update.remove_contacts.push(requester_id.to_proto());
2675    } else {
2676        update
2677            .remove_incoming_requests
2678            .push(requester_id.to_proto());
2679    }
2680    for connection_id in pool.user_connection_ids(responder_id) {
2681        session.peer.send(connection_id, update.clone())?;
2682        if let Some(notification_id) = deleted_notification_id {
2683            session.peer.send(
2684                connection_id,
2685                proto::DeleteNotification {
2686                    notification_id: notification_id.to_proto(),
2687                },
2688            )?;
2689        }
2690    }
2691
2692    response.send(proto::Ack {})?;
2693    Ok(())
2694}
2695
2696fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2697    version.0.minor() < 139
2698}
2699
2700async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2701    let plan = session.current_plan(&session.db().await).await?;
2702
2703    session
2704        .peer
2705        .send(
2706            session.connection_id,
2707            proto::UpdateUserPlan { plan: plan.into() },
2708        )
2709        .trace_err();
2710
2711    Ok(())
2712}
2713
2714async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2715    subscribe_user_to_channels(session.user_id(), &session).await?;
2716    Ok(())
2717}
2718
2719async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2720    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2721    let mut pool = session.connection_pool().await;
2722    for membership in &channels_for_user.channel_memberships {
2723        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2724    }
2725    session.peer.send(
2726        session.connection_id,
2727        build_update_user_channels(&channels_for_user),
2728    )?;
2729    session.peer.send(
2730        session.connection_id,
2731        build_channels_update(channels_for_user),
2732    )?;
2733    Ok(())
2734}
2735
2736/// Creates a new channel.
2737async fn create_channel(
2738    request: proto::CreateChannel,
2739    response: Response<proto::CreateChannel>,
2740    session: Session,
2741) -> Result<()> {
2742    let db = session.db().await;
2743
2744    let parent_id = request.parent_id.map(ChannelId::from_proto);
2745    let (channel, membership) = db
2746        .create_channel(&request.name, parent_id, session.user_id())
2747        .await?;
2748
2749    let root_id = channel.root_id();
2750    let channel = Channel::from_model(channel);
2751
2752    response.send(proto::CreateChannelResponse {
2753        channel: Some(channel.to_proto()),
2754        parent_id: request.parent_id,
2755    })?;
2756
2757    let mut connection_pool = session.connection_pool().await;
2758    if let Some(membership) = membership {
2759        connection_pool.subscribe_to_channel(
2760            membership.user_id,
2761            membership.channel_id,
2762            membership.role,
2763        );
2764        let update = proto::UpdateUserChannels {
2765            channel_memberships: vec![proto::ChannelMembership {
2766                channel_id: membership.channel_id.to_proto(),
2767                role: membership.role.into(),
2768            }],
2769            ..Default::default()
2770        };
2771        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2772            session.peer.send(connection_id, update.clone())?;
2773        }
2774    }
2775
2776    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2777        if !role.can_see_channel(channel.visibility) {
2778            continue;
2779        }
2780
2781        let update = proto::UpdateChannels {
2782            channels: vec![channel.to_proto()],
2783            ..Default::default()
2784        };
2785        session.peer.send(connection_id, update.clone())?;
2786    }
2787
2788    Ok(())
2789}
2790
2791/// Delete a channel
2792async fn delete_channel(
2793    request: proto::DeleteChannel,
2794    response: Response<proto::DeleteChannel>,
2795    session: Session,
2796) -> Result<()> {
2797    let db = session.db().await;
2798
2799    let channel_id = request.channel_id;
2800    let (root_channel, removed_channels) = db
2801        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2802        .await?;
2803    response.send(proto::Ack {})?;
2804
2805    // Notify members of removed channels
2806    let mut update = proto::UpdateChannels::default();
2807    update
2808        .delete_channels
2809        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2810
2811    let connection_pool = session.connection_pool().await;
2812    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2813        session.peer.send(connection_id, update.clone())?;
2814    }
2815
2816    Ok(())
2817}
2818
2819/// Invite someone to join a channel.
2820async fn invite_channel_member(
2821    request: proto::InviteChannelMember,
2822    response: Response<proto::InviteChannelMember>,
2823    session: Session,
2824) -> Result<()> {
2825    let db = session.db().await;
2826    let channel_id = ChannelId::from_proto(request.channel_id);
2827    let invitee_id = UserId::from_proto(request.user_id);
2828    let InviteMemberResult {
2829        channel,
2830        notifications,
2831    } = db
2832        .invite_channel_member(
2833            channel_id,
2834            invitee_id,
2835            session.user_id(),
2836            request.role().into(),
2837        )
2838        .await?;
2839
2840    let update = proto::UpdateChannels {
2841        channel_invitations: vec![channel.to_proto()],
2842        ..Default::default()
2843    };
2844
2845    let connection_pool = session.connection_pool().await;
2846    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2847        session.peer.send(connection_id, update.clone())?;
2848    }
2849
2850    send_notifications(&connection_pool, &session.peer, notifications);
2851
2852    response.send(proto::Ack {})?;
2853    Ok(())
2854}
2855
2856/// remove someone from a channel
2857async fn remove_channel_member(
2858    request: proto::RemoveChannelMember,
2859    response: Response<proto::RemoveChannelMember>,
2860    session: Session,
2861) -> Result<()> {
2862    let db = session.db().await;
2863    let channel_id = ChannelId::from_proto(request.channel_id);
2864    let member_id = UserId::from_proto(request.user_id);
2865
2866    let RemoveChannelMemberResult {
2867        membership_update,
2868        notification_id,
2869    } = db
2870        .remove_channel_member(channel_id, member_id, session.user_id())
2871        .await?;
2872
2873    let mut connection_pool = session.connection_pool().await;
2874    notify_membership_updated(
2875        &mut connection_pool,
2876        membership_update,
2877        member_id,
2878        &session.peer,
2879    );
2880    for connection_id in connection_pool.user_connection_ids(member_id) {
2881        if let Some(notification_id) = notification_id {
2882            session
2883                .peer
2884                .send(
2885                    connection_id,
2886                    proto::DeleteNotification {
2887                        notification_id: notification_id.to_proto(),
2888                    },
2889                )
2890                .trace_err();
2891        }
2892    }
2893
2894    response.send(proto::Ack {})?;
2895    Ok(())
2896}
2897
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
/// NOTE(review): that invariant is presumably enforced inside `db.set_channel_visibility`;
/// nothing in this handler checks it.
///
/// For every user subscribed to the channel's root, the handler does one of two things:
/// - if the user's role can see the new visibility, it (re)subscribes them to the channel
///   and sends the channel;
/// - otherwise it unsubscribes them and sends a delete for the channel id.
///
/// The request is acknowledged only after all updates have been sent.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates the pool
    // (subscribe/unsubscribe), which cannot happen while it is being iterated.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            // The user's role can no longer see this channel: drop the
            // subscription and tell their clients to remove it.
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2944
2945/// Alter the role for a user in the channel.
2946async fn set_channel_member_role(
2947    request: proto::SetChannelMemberRole,
2948    response: Response<proto::SetChannelMemberRole>,
2949    session: Session,
2950) -> Result<()> {
2951    let db = session.db().await;
2952    let channel_id = ChannelId::from_proto(request.channel_id);
2953    let member_id = UserId::from_proto(request.user_id);
2954    let result = db
2955        .set_channel_member_role(
2956            channel_id,
2957            session.user_id(),
2958            member_id,
2959            request.role().into(),
2960        )
2961        .await?;
2962
2963    match result {
2964        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2965            let mut connection_pool = session.connection_pool().await;
2966            notify_membership_updated(
2967                &mut connection_pool,
2968                membership_update,
2969                member_id,
2970                &session.peer,
2971            )
2972        }
2973        db::SetMemberRoleResult::InviteUpdated(channel) => {
2974            let update = proto::UpdateChannels {
2975                channel_invitations: vec![channel.to_proto()],
2976                ..Default::default()
2977            };
2978
2979            for connection_id in session
2980                .connection_pool()
2981                .await
2982                .user_connection_ids(member_id)
2983            {
2984                session.peer.send(connection_id, update.clone())?;
2985            }
2986        }
2987    }
2988
2989    response.send(proto::Ack {})?;
2990    Ok(())
2991}
2992
2993/// Change the name of a channel
2994async fn rename_channel(
2995    request: proto::RenameChannel,
2996    response: Response<proto::RenameChannel>,
2997    session: Session,
2998) -> Result<()> {
2999    let db = session.db().await;
3000    let channel_id = ChannelId::from_proto(request.channel_id);
3001    let channel_model = db
3002        .rename_channel(channel_id, session.user_id(), &request.name)
3003        .await?;
3004    let root_id = channel_model.root_id();
3005    let channel = Channel::from_model(channel_model);
3006
3007    response.send(proto::RenameChannelResponse {
3008        channel: Some(channel.to_proto()),
3009    })?;
3010
3011    let connection_pool = session.connection_pool().await;
3012    let update = proto::UpdateChannels {
3013        channels: vec![channel.to_proto()],
3014        ..Default::default()
3015    };
3016    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3017        if role.can_see_channel(channel.visibility) {
3018            session.peer.send(connection_id, update.clone())?;
3019        }
3020    }
3021
3022    Ok(())
3023}
3024
3025/// Move a channel to a new parent.
3026async fn move_channel(
3027    request: proto::MoveChannel,
3028    response: Response<proto::MoveChannel>,
3029    session: Session,
3030) -> Result<()> {
3031    let channel_id = ChannelId::from_proto(request.channel_id);
3032    let to = ChannelId::from_proto(request.to);
3033
3034    let (root_id, channels) = session
3035        .db()
3036        .await
3037        .move_channel(channel_id, to, session.user_id())
3038        .await?;
3039
3040    let connection_pool = session.connection_pool().await;
3041    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3042        let channels = channels
3043            .iter()
3044            .filter_map(|channel| {
3045                if role.can_see_channel(channel.visibility) {
3046                    Some(channel.to_proto())
3047                } else {
3048                    None
3049                }
3050            })
3051            .collect::<Vec<_>>();
3052        if channels.is_empty() {
3053            continue;
3054        }
3055
3056        let update = proto::UpdateChannels {
3057            channels,
3058            ..Default::default()
3059        };
3060
3061        session.peer.send(connection_id, update.clone())?;
3062    }
3063
3064    response.send(Ack {})?;
3065    Ok(())
3066}
3067
3068/// Get the list of channel members
3069async fn get_channel_members(
3070    request: proto::GetChannelMembers,
3071    response: Response<proto::GetChannelMembers>,
3072    session: Session,
3073) -> Result<()> {
3074    let db = session.db().await;
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    let limit = if request.limit == 0 {
3077        u16::MAX as u64
3078    } else {
3079        request.limit
3080    };
3081    let (members, users) = db
3082        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3083        .await?;
3084    response.send(proto::GetChannelMembersResponse { members, users })?;
3085    Ok(())
3086}
3087
3088/// Accept or decline a channel invitation.
3089async fn respond_to_channel_invite(
3090    request: proto::RespondToChannelInvite,
3091    response: Response<proto::RespondToChannelInvite>,
3092    session: Session,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096    let RespondToChannelInvite {
3097        membership_update,
3098        notifications,
3099    } = db
3100        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3101        .await?;
3102
3103    let mut connection_pool = session.connection_pool().await;
3104    if let Some(membership_update) = membership_update {
3105        notify_membership_updated(
3106            &mut connection_pool,
3107            membership_update,
3108            session.user_id(),
3109            &session.peer,
3110        );
3111    } else {
3112        let update = proto::UpdateChannels {
3113            remove_channel_invitations: vec![channel_id.to_proto()],
3114            ..Default::default()
3115        };
3116
3117        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3118            session.peer.send(connection_id, update.clone())?;
3119        }
3120    };
3121
3122    send_notifications(&connection_pool, &session.peer, notifications);
3123
3124    response.send(proto::Ack {})?;
3125
3126    Ok(())
3127}
3128
3129/// Join the channels' room
3130async fn join_channel(
3131    request: proto::JoinChannel,
3132    response: Response<proto::JoinChannel>,
3133    session: Session,
3134) -> Result<()> {
3135    let channel_id = ChannelId::from_proto(request.channel_id);
3136    join_channel_internal(channel_id, Box::new(response), session).await
3137}
3138
/// A response handle that can reply with a `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` response handles so that
/// `join_channel_internal` can answer either kind of request.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response`'s own `send`. Type-qualified paths resolve inherent
// methods before trait methods, so this does not recurse into the trait impl.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation for `JoinRoom` requests.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3152
/// Shared implementation for joining a channel's room.
///
/// The steps are:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database.
/// 3. Reply with the room, the channel id, and (when a LiveKit client is
///    configured) LiveKit connection info.
/// 4. Propagate any membership change and broadcast the updated room.
/// 5. Once the database lock is released, send the channel update and refresh
///    the user's contacts.
///
/// # Errors
/// Fails if any database call fails, if sending the response fails, or if the
/// database did not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The session's database guard lives only inside this block; the
    // follow-up work after it runs without holding it.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before `leave_room_for_session`, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and `can_publish: false`; every other role
        // gets a room token and `can_publish: true`. If minting the token fails,
        // the error goes through `trace_err` and the response carries no LiveKit
        // info (the closure yields `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database guard and the pool guard from the block above are dropped
    // at this point; a fresh pool guard is taken just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3248
3249/// Start editing the channel notes
3250async fn join_channel_buffer(
3251    request: proto::JoinChannelBuffer,
3252    response: Response<proto::JoinChannelBuffer>,
3253    session: Session,
3254) -> Result<()> {
3255    let db = session.db().await;
3256    let channel_id = ChannelId::from_proto(request.channel_id);
3257
3258    let open_response = db
3259        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3260        .await?;
3261
3262    let collaborators = open_response.collaborators.clone();
3263    response.send(open_response)?;
3264
3265    let update = UpdateChannelBufferCollaborators {
3266        channel_id: channel_id.to_proto(),
3267        collaborators: collaborators.clone(),
3268    };
3269    channel_buffer_updated(
3270        session.connection_id,
3271        collaborators
3272            .iter()
3273            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3274        &update,
3275        &session.peer,
3276    );
3277
3278    Ok(())
3279}
3280
/// Edit the channel notes.
///
/// The handler does three things:
/// 1. Applies `request.operations` to the channel buffer in the database.
/// 2. Forwards the operations to the buffer's collaborators via
///    `channel_buffer_updated`.
/// 3. Sends every other connection subscribed to the channel (the
///    non-collaborators) the buffer's latest epoch and version.
///
/// NOTE(review): the session's database guard `db` stays held for the whole
/// body, including the broadcasts.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // `collaborators` is cloned because it is needed again below to filter out
    // the connections that already received the operations.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers that are not editing the buffer only need the new version.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    // NOTE(review): the first `broadcast` argument appears to be a sender to
    // skip; `None` skips nobody here — confirm against `broadcast`.
    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    // NOTE(review): `as u64` assumes `epoch` is never negative.
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3331
3332/// Rejoin the channel notes after a connection blip
3333async fn rejoin_channel_buffers(
3334    request: proto::RejoinChannelBuffers,
3335    response: Response<proto::RejoinChannelBuffers>,
3336    session: Session,
3337) -> Result<()> {
3338    let db = session.db().await;
3339    let buffers = db
3340        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3341        .await?;
3342
3343    for rejoined_buffer in &buffers {
3344        let collaborators_to_notify = rejoined_buffer
3345            .buffer
3346            .collaborators
3347            .iter()
3348            .filter_map(|c| Some(c.peer_id?.into()));
3349        channel_buffer_updated(
3350            session.connection_id,
3351            collaborators_to_notify,
3352            &proto::UpdateChannelBufferCollaborators {
3353                channel_id: rejoined_buffer.buffer.channel_id,
3354                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3355            },
3356            &session.peer,
3357        );
3358    }
3359
3360    response.send(proto::RejoinChannelBuffersResponse {
3361        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3362    })?;
3363
3364    Ok(())
3365}
3366
3367/// Stop editing the channel notes
3368async fn leave_channel_buffer(
3369    request: proto::LeaveChannelBuffer,
3370    response: Response<proto::LeaveChannelBuffer>,
3371    session: Session,
3372) -> Result<()> {
3373    let db = session.db().await;
3374    let channel_id = ChannelId::from_proto(request.channel_id);
3375
3376    let left_buffer = db
3377        .leave_channel_buffer(channel_id, session.connection_id)
3378        .await?;
3379
3380    response.send(Ack {})?;
3381
3382    channel_buffer_updated(
3383        session.connection_id,
3384        left_buffer.connections,
3385        &proto::UpdateChannelBufferCollaborators {
3386            channel_id: channel_id.to_proto(),
3387            collaborators: left_buffer.collaborators,
3388        },
3389        &session.peer,
3390    );
3391
3392    Ok(())
3393}
3394
3395fn channel_buffer_updated<T: EnvelopedMessage>(
3396    sender_id: ConnectionId,
3397    collaborators: impl IntoIterator<Item = ConnectionId>,
3398    message: &T,
3399    peer: &Peer,
3400) {
3401    broadcast(Some(sender_id), collaborators, |peer_id| {
3402        peer.send(peer_id, message.clone())
3403    });
3404}
3405
3406fn send_notifications(
3407    connection_pool: &ConnectionPool,
3408    peer: &Peer,
3409    notifications: db::NotificationBatch,
3410) {
3411    for (user_id, notification) in notifications {
3412        for connection_id in connection_pool.user_connection_ids(user_id) {
3413            if let Err(error) = peer.send(
3414                connection_id,
3415                proto::AddNotification {
3416                    notification: Some(notification.clone()),
3417                },
3418            ) {
3419                tracing::error!(
3420                    "failed to send notification to {:?} {}",
3421                    connection_id,
3422                    error
3423                );
3424            }
3425        }
3426    }
3427}
3428
3429/// Send a message to the channel
3430async fn send_channel_message(
3431    request: proto::SendChannelMessage,
3432    response: Response<proto::SendChannelMessage>,
3433    session: Session,
3434) -> Result<()> {
3435    // Validate the message body.
3436    let body = request.body.trim().to_string();
3437    if body.len() > MAX_MESSAGE_LEN {
3438        return Err(anyhow!("message is too long"))?;
3439    }
3440    if body.is_empty() {
3441        return Err(anyhow!("message can't be blank"))?;
3442    }
3443
3444    // TODO: adjust mentions if body is trimmed
3445
3446    let timestamp = OffsetDateTime::now_utc();
3447    let nonce = request
3448        .nonce
3449        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3450
3451    let channel_id = ChannelId::from_proto(request.channel_id);
3452    let CreatedChannelMessage {
3453        message_id,
3454        participant_connection_ids,
3455        notifications,
3456    } = session
3457        .db()
3458        .await
3459        .create_channel_message(
3460            channel_id,
3461            session.user_id(),
3462            &body,
3463            &request.mentions,
3464            timestamp,
3465            nonce.clone().into(),
3466            request.reply_to_message_id.map(MessageId::from_proto),
3467        )
3468        .await?;
3469
3470    let message = proto::ChannelMessage {
3471        sender_id: session.user_id().to_proto(),
3472        id: message_id.to_proto(),
3473        body,
3474        mentions: request.mentions,
3475        timestamp: timestamp.unix_timestamp() as u64,
3476        nonce: Some(nonce),
3477        reply_to_message_id: request.reply_to_message_id,
3478        edited_at: None,
3479    };
3480    broadcast(
3481        Some(session.connection_id),
3482        participant_connection_ids.clone(),
3483        |connection| {
3484            session.peer.send(
3485                connection,
3486                proto::ChannelMessageSent {
3487                    channel_id: channel_id.to_proto(),
3488                    message: Some(message.clone()),
3489                },
3490            )
3491        },
3492    );
3493    response.send(proto::SendChannelMessageResponse {
3494        message: Some(message),
3495    })?;
3496
3497    let pool = &*session.connection_pool().await;
3498    let non_participants =
3499        pool.channel_connection_ids(channel_id)
3500            .filter_map(|(connection_id, _)| {
3501                if participant_connection_ids.contains(&connection_id) {
3502                    None
3503                } else {
3504                    Some(connection_id)
3505                }
3506            });
3507    broadcast(None, non_participants, |peer_id| {
3508        session.peer.send(
3509            peer_id,
3510            proto::UpdateChannels {
3511                latest_channel_message_ids: vec![proto::ChannelMessageId {
3512                    channel_id: channel_id.to_proto(),
3513                    message_id: message_id.to_proto(),
3514                }],
3515                ..Default::default()
3516            },
3517        )
3518    });
3519    send_notifications(pool, &session.peer, notifications);
3520
3521    Ok(())
3522}
3523
3524/// Delete a channel message
3525async fn remove_channel_message(
3526    request: proto::RemoveChannelMessage,
3527    response: Response<proto::RemoveChannelMessage>,
3528    session: Session,
3529) -> Result<()> {
3530    let channel_id = ChannelId::from_proto(request.channel_id);
3531    let message_id = MessageId::from_proto(request.message_id);
3532    let (connection_ids, existing_notification_ids) = session
3533        .db()
3534        .await
3535        .remove_channel_message(channel_id, message_id, session.user_id())
3536        .await?;
3537
3538    broadcast(
3539        Some(session.connection_id),
3540        connection_ids,
3541        move |connection| {
3542            session.peer.send(connection, request.clone())?;
3543
3544            for notification_id in &existing_notification_ids {
3545                session.peer.send(
3546                    connection,
3547                    proto::DeleteNotification {
3548                        notification_id: (*notification_id).to_proto(),
3549                    },
3550                )?;
3551            }
3552
3553            Ok(())
3554        },
3555    );
3556    response.send(proto::Ack {})?;
3557    Ok(())
3558}
3559
3560async fn update_channel_message(
3561    request: proto::UpdateChannelMessage,
3562    response: Response<proto::UpdateChannelMessage>,
3563    session: Session,
3564) -> Result<()> {
3565    let channel_id = ChannelId::from_proto(request.channel_id);
3566    let message_id = MessageId::from_proto(request.message_id);
3567    let updated_at = OffsetDateTime::now_utc();
3568    let UpdatedChannelMessage {
3569        message_id,
3570        participant_connection_ids,
3571        notifications,
3572        reply_to_message_id,
3573        timestamp,
3574        deleted_mention_notification_ids,
3575        updated_mention_notifications,
3576    } = session
3577        .db()
3578        .await
3579        .update_channel_message(
3580            channel_id,
3581            message_id,
3582            session.user_id(),
3583            request.body.as_str(),
3584            &request.mentions,
3585            updated_at,
3586        )
3587        .await?;
3588
3589    let nonce = request
3590        .nonce
3591        .clone()
3592        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3593
3594    let message = proto::ChannelMessage {
3595        sender_id: session.user_id().to_proto(),
3596        id: message_id.to_proto(),
3597        body: request.body.clone(),
3598        mentions: request.mentions.clone(),
3599        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3600        nonce: Some(nonce),
3601        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3602        edited_at: Some(updated_at.unix_timestamp() as u64),
3603    };
3604
3605    response.send(proto::Ack {})?;
3606
3607    let pool = &*session.connection_pool().await;
3608    broadcast(
3609        Some(session.connection_id),
3610        participant_connection_ids,
3611        |connection| {
3612            session.peer.send(
3613                connection,
3614                proto::ChannelMessageUpdate {
3615                    channel_id: channel_id.to_proto(),
3616                    message: Some(message.clone()),
3617                },
3618            )?;
3619
3620            for notification_id in &deleted_mention_notification_ids {
3621                session.peer.send(
3622                    connection,
3623                    proto::DeleteNotification {
3624                        notification_id: (*notification_id).to_proto(),
3625                    },
3626                )?;
3627            }
3628
3629            for notification in &updated_mention_notifications {
3630                session.peer.send(
3631                    connection,
3632                    proto::UpdateNotification {
3633                        notification: Some(notification.clone()),
3634                    },
3635                )?;
3636            }
3637
3638            Ok(())
3639        },
3640    );
3641
3642    send_notifications(pool, &session.peer, notifications);
3643
3644    Ok(())
3645}
3646
3647/// Mark a channel message as read
3648async fn acknowledge_channel_message(
3649    request: proto::AckChannelMessage,
3650    session: Session,
3651) -> Result<()> {
3652    let channel_id = ChannelId::from_proto(request.channel_id);
3653    let message_id = MessageId::from_proto(request.message_id);
3654    let notifications = session
3655        .db()
3656        .await
3657        .observe_channel_message(channel_id, session.user_id(), message_id)
3658        .await?;
3659    send_notifications(
3660        &*session.connection_pool().await,
3661        &session.peer,
3662        notifications,
3663    );
3664    Ok(())
3665}
3666
3667/// Mark a buffer version as synced
3668async fn acknowledge_buffer_version(
3669    request: proto::AckBufferOperation,
3670    session: Session,
3671) -> Result<()> {
3672    let buffer_id = BufferId::from_proto(request.buffer_id);
3673    session
3674        .db()
3675        .await
3676        .observe_buffer_version(
3677            buffer_id,
3678            session.user_id(),
3679            request.epoch as i32,
3680            &request.version,
3681        )
3682        .await?;
3683    Ok(())
3684}
3685
3686/// Get a Supermaven API key for the user
3687async fn get_supermaven_api_key(
3688    _request: proto::GetSupermavenApiKey,
3689    response: Response<proto::GetSupermavenApiKey>,
3690    session: Session,
3691) -> Result<()> {
3692    let user_id: String = session.user_id().to_string();
3693    if !session.is_staff() {
3694        return Err(anyhow!("supermaven not enabled for this account"))?;
3695    }
3696
3697    let email = session
3698        .email()
3699        .ok_or_else(|| anyhow!("user must have an email"))?;
3700
3701    let supermaven_admin_api = session
3702        .supermaven_client
3703        .as_ref()
3704        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3705
3706    let result = supermaven_admin_api
3707        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3708        .await?;
3709
3710    response.send(proto::GetSupermavenApiKeyResponse {
3711        api_key: result.api_key,
3712    })?;
3713
3714    Ok(())
3715}
3716
3717/// Start receiving chat updates for a channel
3718async fn join_channel_chat(
3719    request: proto::JoinChannelChat,
3720    response: Response<proto::JoinChannelChat>,
3721    session: Session,
3722) -> Result<()> {
3723    let channel_id = ChannelId::from_proto(request.channel_id);
3724
3725    let db = session.db().await;
3726    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3727        .await?;
3728    let messages = db
3729        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3730        .await?;
3731    response.send(proto::JoinChannelChatResponse {
3732        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3733        messages,
3734    })?;
3735    Ok(())
3736}
3737
3738/// Stop receiving chat updates for a channel
3739async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3740    let channel_id = ChannelId::from_proto(request.channel_id);
3741    session
3742        .db()
3743        .await
3744        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3745        .await?;
3746    Ok(())
3747}
3748
3749/// Retrieve the chat history for a channel
3750async fn get_channel_messages(
3751    request: proto::GetChannelMessages,
3752    response: Response<proto::GetChannelMessages>,
3753    session: Session,
3754) -> Result<()> {
3755    let channel_id = ChannelId::from_proto(request.channel_id);
3756    let messages = session
3757        .db()
3758        .await
3759        .get_channel_messages(
3760            channel_id,
3761            session.user_id(),
3762            MESSAGE_COUNT_PER_PAGE,
3763            Some(MessageId::from_proto(request.before_message_id)),
3764        )
3765        .await?;
3766    response.send(proto::GetChannelMessagesResponse {
3767        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3768        messages,
3769    })?;
3770    Ok(())
3771}
3772
3773/// Retrieve specific chat messages
3774async fn get_channel_messages_by_id(
3775    request: proto::GetChannelMessagesById,
3776    response: Response<proto::GetChannelMessagesById>,
3777    session: Session,
3778) -> Result<()> {
3779    let message_ids = request
3780        .message_ids
3781        .iter()
3782        .map(|id| MessageId::from_proto(*id))
3783        .collect::<Vec<_>>();
3784    let messages = session
3785        .db()
3786        .await
3787        .get_channel_messages_by_id(session.user_id(), &message_ids)
3788        .await?;
3789    response.send(proto::GetChannelMessagesResponse {
3790        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3791        messages,
3792    })?;
3793    Ok(())
3794}
3795
3796/// Retrieve the current users notifications
3797async fn get_notifications(
3798    request: proto::GetNotifications,
3799    response: Response<proto::GetNotifications>,
3800    session: Session,
3801) -> Result<()> {
3802    let notifications = session
3803        .db()
3804        .await
3805        .get_notifications(
3806            session.user_id(),
3807            NOTIFICATION_COUNT_PER_PAGE,
3808            request.before_id.map(db::NotificationId::from_proto),
3809        )
3810        .await?;
3811    response.send(proto::GetNotificationsResponse {
3812        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3813        notifications,
3814    })?;
3815    Ok(())
3816}
3817
3818/// Mark notifications as read
3819async fn mark_notification_as_read(
3820    request: proto::MarkNotificationRead,
3821    response: Response<proto::MarkNotificationRead>,
3822    session: Session,
3823) -> Result<()> {
3824    let database = &session.db().await;
3825    let notifications = database
3826        .mark_notification_as_read_by_id(
3827            session.user_id(),
3828            NotificationId::from_proto(request.notification_id),
3829        )
3830        .await?;
3831    send_notifications(
3832        &*session.connection_pool().await,
3833        &session.peer,
3834        notifications,
3835    );
3836    response.send(proto::Ack {})?;
3837    Ok(())
3838}
3839
3840/// Get the current users information
3841async fn get_private_user_info(
3842    _request: proto::GetPrivateUserInfo,
3843    response: Response<proto::GetPrivateUserInfo>,
3844    session: Session,
3845) -> Result<()> {
3846    let db = session.db().await;
3847
3848    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3849    let user = db
3850        .get_user_by_id(session.user_id())
3851        .await?
3852        .ok_or_else(|| anyhow!("user not found"))?;
3853    let flags = db.get_user_flags(session.user_id()).await?;
3854
3855    response.send(proto::GetPrivateUserInfoResponse {
3856        metrics_id,
3857        staff: user.admin,
3858        flags,
3859        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3860    })?;
3861    Ok(())
3862}
3863
3864/// Accept the terms of service (tos) on behalf of the current user
3865async fn accept_terms_of_service(
3866    _request: proto::AcceptTermsOfService,
3867    response: Response<proto::AcceptTermsOfService>,
3868    session: Session,
3869) -> Result<()> {
3870    let db = session.db().await;
3871
3872    let accepted_tos_at = Utc::now();
3873    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3874        .await?;
3875
3876    response.send(proto::AcceptTermsOfServiceResponse {
3877        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3878    })?;
3879    Ok(())
3880}
3881
3882/// The minimum account age an account must have in order to use the LLM service.
3883pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3884
3885async fn get_llm_api_token(
3886    _request: proto::GetLlmToken,
3887    response: Response<proto::GetLlmToken>,
3888    session: Session,
3889) -> Result<()> {
3890    let db = session.db().await;
3891
3892    let flags = db.get_user_flags(session.user_id()).await?;
3893    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3894
3895    if !session.is_staff() && !has_language_models_feature_flag {
3896        Err(anyhow!("permission denied"))?
3897    }
3898
3899    let user_id = session.user_id();
3900    let user = db
3901        .get_user_by_id(user_id)
3902        .await?
3903        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3904
3905    if user.accepted_tos_at.is_none() {
3906        Err(anyhow!("terms of service not accepted"))?
3907    }
3908
3909    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
3910    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
3911    let billing_preferences = db.get_billing_preferences(user.id).await?;
3912
3913    let token = LlmTokenClaims::create(
3914        &user,
3915        session.is_staff(),
3916        billing_preferences,
3917        &flags,
3918        has_legacy_llm_subscription,
3919        billing_subscription,
3920        session.system_id.clone(),
3921        &session.app_state.config,
3922    )?;
3923    response.send(proto::GetLlmTokenResponse { token })?;
3924    Ok(())
3925}
3926
3927fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3928    let message = match message {
3929        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3930        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3931        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3932        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3933        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3934            code: frame.code.into(),
3935            reason: frame.reason.as_str().to_owned().into(),
3936        })),
3937        // We should never receive a frame while reading the message, according
3938        // to the `tungstenite` maintainers:
3939        //
3940        // > It cannot occur when you read messages from the WebSocket, but it
3941        // > can be used when you want to send the raw frames (e.g. you want to
3942        // > send the frames to the WebSocket without composing the full message first).
3943        // >
3944        // > — https://github.com/snapview/tungstenite-rs/issues/268
3945        TungsteniteMessage::Frame(_) => {
3946            bail!("received an unexpected frame while reading the message")
3947        }
3948    };
3949
3950    Ok(message)
3951}
3952
3953fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3954    match message {
3955        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3956        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3957        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3958        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3959        AxumMessage::Close(frame) => {
3960            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3961                code: frame.code.into(),
3962                reason: frame.reason.as_ref().into(),
3963            }))
3964        }
3965    }
3966}
3967
3968fn notify_membership_updated(
3969    connection_pool: &mut ConnectionPool,
3970    result: MembershipUpdated,
3971    user_id: UserId,
3972    peer: &Peer,
3973) {
3974    for membership in &result.new_channels.channel_memberships {
3975        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3976    }
3977    for channel_id in &result.removed_channels {
3978        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3979    }
3980
3981    let user_channels_update = proto::UpdateUserChannels {
3982        channel_memberships: result
3983            .new_channels
3984            .channel_memberships
3985            .iter()
3986            .map(|cm| proto::ChannelMembership {
3987                channel_id: cm.channel_id.to_proto(),
3988                role: cm.role.into(),
3989            })
3990            .collect(),
3991        ..Default::default()
3992    };
3993
3994    let mut update = build_channels_update(result.new_channels);
3995    update.delete_channels = result
3996        .removed_channels
3997        .into_iter()
3998        .map(|id| id.to_proto())
3999        .collect();
4000    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4001
4002    for connection_id in connection_pool.user_connection_ids(user_id) {
4003        peer.send(connection_id, user_channels_update.clone())
4004            .trace_err();
4005        peer.send(connection_id, update.clone()).trace_err();
4006    }
4007}
4008
4009fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4010    proto::UpdateUserChannels {
4011        channel_memberships: channels
4012            .channel_memberships
4013            .iter()
4014            .map(|m| proto::ChannelMembership {
4015                channel_id: m.channel_id.to_proto(),
4016                role: m.role.into(),
4017            })
4018            .collect(),
4019        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4020        observed_channel_message_id: channels.observed_channel_messages.clone(),
4021    }
4022}
4023
4024fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4025    let mut update = proto::UpdateChannels::default();
4026
4027    for channel in channels.channels {
4028        update.channels.push(channel.to_proto());
4029    }
4030
4031    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4032    update.latest_channel_message_ids = channels.latest_channel_messages;
4033
4034    for (channel_id, participants) in channels.channel_participants {
4035        update
4036            .channel_participants
4037            .push(proto::ChannelParticipants {
4038                channel_id: channel_id.to_proto(),
4039                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4040            });
4041    }
4042
4043    for channel in channels.invited_channels {
4044        update.channel_invitations.push(channel.to_proto());
4045    }
4046
4047    update
4048}
4049
4050fn build_initial_contacts_update(
4051    contacts: Vec<db::Contact>,
4052    pool: &ConnectionPool,
4053) -> proto::UpdateContacts {
4054    let mut update = proto::UpdateContacts::default();
4055
4056    for contact in contacts {
4057        match contact {
4058            db::Contact::Accepted { user_id, busy } => {
4059                update.contacts.push(contact_for_user(user_id, busy, pool));
4060            }
4061            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4062            db::Contact::Incoming { user_id } => {
4063                update
4064                    .incoming_requests
4065                    .push(proto::IncomingContactRequest {
4066                        requester_id: user_id.to_proto(),
4067                    })
4068            }
4069        }
4070    }
4071
4072    update
4073}
4074
4075fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4076    proto::Contact {
4077        user_id: user_id.to_proto(),
4078        online: pool.is_user_online(user_id),
4079        busy,
4080    }
4081}
4082
4083fn room_updated(room: &proto::Room, peer: &Peer) {
4084    broadcast(
4085        None,
4086        room.participants
4087            .iter()
4088            .filter_map(|participant| Some(participant.peer_id?.into())),
4089        |peer_id| {
4090            peer.send(
4091                peer_id,
4092                proto::RoomUpdated {
4093                    room: Some(room.clone()),
4094                },
4095            )
4096        },
4097    );
4098}
4099
4100fn channel_updated(
4101    channel: &db::channel::Model,
4102    room: &proto::Room,
4103    peer: &Peer,
4104    pool: &ConnectionPool,
4105) {
4106    let participants = room
4107        .participants
4108        .iter()
4109        .map(|p| p.user_id)
4110        .collect::<Vec<_>>();
4111
4112    broadcast(
4113        None,
4114        pool.channel_connection_ids(channel.root_id())
4115            .filter_map(|(channel_id, role)| {
4116                role.can_see_channel(channel.visibility)
4117                    .then_some(channel_id)
4118            }),
4119        |peer_id| {
4120            peer.send(
4121                peer_id,
4122                proto::UpdateChannels {
4123                    channel_participants: vec![proto::ChannelParticipants {
4124                        channel_id: channel.id.to_proto(),
4125                        participant_user_ids: participants.clone(),
4126                    }],
4127                    ..Default::default()
4128                },
4129            )
4130        },
4131    );
4132}
4133
4134async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4135    let db = session.db().await;
4136
4137    let contacts = db.get_contacts(user_id).await?;
4138    let busy = db.is_user_busy(user_id).await?;
4139
4140    let pool = session.connection_pool().await;
4141    let updated_contact = contact_for_user(user_id, busy, &pool);
4142    for contact in contacts {
4143        if let db::Contact::Accepted {
4144            user_id: contact_user_id,
4145            ..
4146        } = contact
4147        {
4148            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4149                session
4150                    .peer
4151                    .send(
4152                        contact_conn_id,
4153                        proto::UpdateContacts {
4154                            contacts: vec![updated_contact.clone()],
4155                            remove_contacts: Default::default(),
4156                            incoming_requests: Default::default(),
4157                            remove_incoming_requests: Default::default(),
4158                            outgoing_requests: Default::default(),
4159                            remove_outgoing_requests: Default::default(),
4160                        },
4161                    )
4162                    .trace_err();
4163            }
4164        }
4165    }
4166    Ok(())
4167}
4168
4169async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4170    let mut contacts_to_update = HashSet::default();
4171
4172    let room_id;
4173    let canceled_calls_to_user_ids;
4174    let livekit_room;
4175    let delete_livekit_room;
4176    let room;
4177    let channel;
4178
4179    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4180        contacts_to_update.insert(session.user_id());
4181
4182        for project in left_room.left_projects.values() {
4183            project_left(project, session);
4184        }
4185
4186        room_id = RoomId::from_proto(left_room.room.id);
4187        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4188        livekit_room = mem::take(&mut left_room.room.livekit_room);
4189        delete_livekit_room = left_room.deleted;
4190        room = mem::take(&mut left_room.room);
4191        channel = mem::take(&mut left_room.channel);
4192
4193        room_updated(&room, &session.peer);
4194    } else {
4195        return Ok(());
4196    }
4197
4198    if let Some(channel) = channel {
4199        channel_updated(
4200            &channel,
4201            &room,
4202            &session.peer,
4203            &*session.connection_pool().await,
4204        );
4205    }
4206
4207    {
4208        let pool = session.connection_pool().await;
4209        for canceled_user_id in canceled_calls_to_user_ids {
4210            for connection_id in pool.user_connection_ids(canceled_user_id) {
4211                session
4212                    .peer
4213                    .send(
4214                        connection_id,
4215                        proto::CallCanceled {
4216                            room_id: room_id.to_proto(),
4217                        },
4218                    )
4219                    .trace_err();
4220            }
4221            contacts_to_update.insert(canceled_user_id);
4222        }
4223    }
4224
4225    for contact_user_id in contacts_to_update {
4226        update_user_contacts(contact_user_id, session).await?;
4227    }
4228
4229    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4230        live_kit
4231            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4232            .await
4233            .trace_err();
4234
4235        if delete_livekit_room {
4236            live_kit.delete_room(livekit_room).await.trace_err();
4237        }
4238    }
4239
4240    Ok(())
4241}
4242
4243async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4244    let left_channel_buffers = session
4245        .db()
4246        .await
4247        .leave_channel_buffers(session.connection_id)
4248        .await?;
4249
4250    for left_buffer in left_channel_buffers {
4251        channel_buffer_updated(
4252            session.connection_id,
4253            left_buffer.connections,
4254            &proto::UpdateChannelBufferCollaborators {
4255                channel_id: left_buffer.channel_id.to_proto(),
4256                collaborators: left_buffer.collaborators,
4257            },
4258            &session.peer,
4259        );
4260    }
4261
4262    Ok(())
4263}
4264
4265fn project_left(project: &db::LeftProject, session: &Session) {
4266    for connection_id in &project.connection_ids {
4267        if project.should_unshare {
4268            session
4269                .peer
4270                .send(
4271                    *connection_id,
4272                    proto::UnshareProject {
4273                        project_id: project.id.to_proto(),
4274                    },
4275                )
4276                .trace_err();
4277        } else {
4278            session
4279                .peer
4280                .send(
4281                    *connection_id,
4282                    proto::RemoveProjectCollaborator {
4283                        project_id: project.id.to_proto(),
4284                        peer_id: Some(session.connection_id.into()),
4285                    },
4286                )
4287                .trace_err();
4288        }
4289    }
4290}
4291
4292pub trait ResultExt {
4293    type Ok;
4294
4295    fn trace_err(self) -> Option<Self::Ok>;
4296}
4297
4298impl<T, E> ResultExt for Result<T, E>
4299where
4300    E: std::fmt::Debug,
4301{
4302    type Ok = T;
4303
4304    #[track_caller]
4305    fn trace_err(self) -> Option<T> {
4306        match self {
4307            Ok(value) => Some(value),
4308            Err(error) => {
4309                tracing::error!("{:?}", error);
4310                None
4311            }
4312        }
4313    }
4314}