rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::LlmTokenClaims;
   6use crate::{
   7    AppState, Error, Result, auth,
   8    db::{
   9        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  10        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  11        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  12        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15};
  16use anyhow::{Context as _, anyhow, bail};
  17use async_tungstenite::tungstenite::{
  18    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  19};
  20use axum::{
  21    Extension, Router, TypedHeader,
  22    body::Body,
  23    extract::{
  24        ConnectInfo, WebSocketUpgrade,
  25        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{MutexGuard, Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
/// How long a dropped client is given to reconnect.
/// NOTE(review): meaning inferred from the name; the uses of this constant are outside this view.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before looking up and clearing stale
// server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel message queries (presumably; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length.
/// NOTE(review): used outside this view — confirm whether this counts bytes or chars.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: takes a boxed envelope plus the sender's
/// `Session` and returns a boxed future that completes once the message has been handled.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone, Debug)]
 105pub enum Principal {
 106    User(User),
 107    Impersonated { user: User, admin: User },
 108}
 109
 110impl Principal {
 111    fn update_span(&self, span: &tracing::Span) {
 112        match &self {
 113            Principal::User(user) => {
 114                span.record("user_id", user.id.0);
 115                span.record("login", &user.github_login);
 116            }
 117            Principal::Impersonated { user, admin } => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120                span.record("impersonator", &admin.github_login);
 121            }
 122        }
 123    }
 124}
 125
 126#[derive(Clone)]
 127struct Session {
 128    principal: Principal,
 129    connection_id: ConnectionId,
 130    db: Arc<tokio::sync::Mutex<DbHandle>>,
 131    peer: Arc<Peer>,
 132    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 133    app_state: Arc<AppState>,
 134    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 135    /// The GeoIP country code for the user.
 136    #[allow(unused)]
 137    geoip_country_code: Option<String>,
 138    system_id: Option<String>,
 139    _executor: Executor,
 140}
 141
 142impl Session {
 143    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 144        #[cfg(test)]
 145        tokio::task::yield_now().await;
 146        let guard = self.db.lock().await;
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        guard
 150    }
 151
 152    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.connection_pool.lock();
 156        ConnectionPoolGuard {
 157            guard,
 158            _not_send: PhantomData,
 159        }
 160    }
 161
 162    fn is_staff(&self) -> bool {
 163        match &self.principal {
 164            Principal::User(user) => user.admin,
 165            Principal::Impersonated { .. } => true,
 166        }
 167    }
 168
 169    pub async fn has_llm_subscription(
 170        &self,
 171        db: &MutexGuard<'_, DbHandle>,
 172    ) -> anyhow::Result<bool> {
 173        if self.is_staff() {
 174            return Ok(true);
 175        }
 176
 177        let user_id = self.user_id();
 178
 179        Ok(db.has_active_billing_subscription(user_id).await?)
 180    }
 181
 182    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 183        let user_id = self.user_id();
 184
 185        let subscription = db.get_active_billing_subscription(user_id).await?;
 186        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
 187
 188        let plan = if let Some(subscription_kind) = subscription_kind {
 189            match subscription_kind {
 190                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
 191                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
 192                SubscriptionKind::ZedFree => proto::Plan::Free,
 193            }
 194        } else {
 195            proto::Plan::Free
 196        };
 197
 198        Ok(plan)
 199    }
 200
 201    fn user_id(&self) -> UserId {
 202        match &self.principal {
 203            Principal::User(user) => user.id,
 204            Principal::Impersonated { user, .. } => user.id,
 205        }
 206    }
 207
 208    pub fn email(&self) -> Option<String> {
 209        match &self.principal {
 210            Principal::User(user) => user.email_address.clone(),
 211            Principal::Impersonated { user, .. } => user.email_address.clone(),
 212        }
 213    }
 214}
 215
 216impl Debug for Session {
 217    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 218        let mut result = f.debug_struct("Session");
 219        match &self.principal {
 220            Principal::User(user) => {
 221                result.field("user", &user.github_login);
 222            }
 223            Principal::Impersonated { user, admin } => {
 224                result.field("user", &user.github_login);
 225                result.field("impersonator", &admin.github_login);
 226            }
 227        }
 228        result.field("connection_id", &self.connection_id).finish()
 229    }
 230}
 231
 232struct DbHandle(Arc<Database>);
 233
 234impl Deref for DbHandle {
 235    type Target = Database;
 236
 237    fn deref(&self) -> &Self::Target {
 238        self.0.as_ref()
 239    }
 240}
 241
/// The collaboration server: owns the RPC `Peer`, the shared connection pool, and the handler
/// table that `Server::new` fills in.
pub struct Server {
    /// Id of this server instance. Kept behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Receives `true` when `teardown` runs; the test-only `reset` sends `false` afterwards.
    teardown: watch::Sender<bool>,
}

/// Guard returned by `Session::connection_pool`, holding the connection pool's lock.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so any future that keeps it
/// alive across an `.await` is itself `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and connection pool. The pool's lock is held
/// for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target; that `Deref` impl is not in this view.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 262
 263pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 264where
 265    S: Serializer,
 266    T: Deref<Target = U>,
 267    U: Serialize,
 268{
 269    Serialize::serialize(value.deref(), serializer)
 270}
 271
 272impl Server {
 273    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 274        let mut server = Self {
 275            id: parking_lot::Mutex::new(id),
 276            peer: Peer::new(id.0 as u32),
 277            app_state: app_state.clone(),
 278            connection_pool: Default::default(),
 279            handlers: Default::default(),
 280            teardown: watch::channel(false).0,
 281        };
 282
 283        server
 284            .add_request_handler(ping)
 285            .add_request_handler(create_room)
 286            .add_request_handler(join_room)
 287            .add_request_handler(rejoin_room)
 288            .add_request_handler(leave_room)
 289            .add_request_handler(set_room_participant_role)
 290            .add_request_handler(call)
 291            .add_request_handler(cancel_call)
 292            .add_message_handler(decline_call)
 293            .add_request_handler(update_participant_location)
 294            .add_request_handler(share_project)
 295            .add_message_handler(unshare_project)
 296            .add_request_handler(join_project)
 297            .add_message_handler(leave_project)
 298            .add_request_handler(update_project)
 299            .add_request_handler(update_worktree)
 300            .add_request_handler(update_repository)
 301            .add_request_handler(remove_repository)
 302            .add_message_handler(start_language_server)
 303            .add_message_handler(update_language_server)
 304            .add_message_handler(update_diagnostic_summary)
 305            .add_message_handler(update_worktree_settings)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 309            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 310            .add_request_handler(forward_find_search_candidates_request)
 311            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 316            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 317            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 318            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 320            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 326            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 327            .add_request_handler(
 328                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 329            )
 330            .add_request_handler(
 331                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 332            )
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 337            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 342            .add_request_handler(
 343                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 344            )
 345            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 346            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 347            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 348            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 349            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 351            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 352            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 353            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 354            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 355            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 356            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 357            .add_request_handler(
 358                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 359            )
 360            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 361            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 362            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 363            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 365            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 366            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 367            .add_message_handler(create_buffer_for_peer)
 368            .add_request_handler(update_buffer)
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 375            .add_request_handler(get_users)
 376            .add_request_handler(fuzzy_search_users)
 377            .add_request_handler(request_contact)
 378            .add_request_handler(remove_contact)
 379            .add_request_handler(respond_to_contact_request)
 380            .add_message_handler(subscribe_to_channels)
 381            .add_request_handler(create_channel)
 382            .add_request_handler(delete_channel)
 383            .add_request_handler(invite_channel_member)
 384            .add_request_handler(remove_channel_member)
 385            .add_request_handler(set_channel_member_role)
 386            .add_request_handler(set_channel_visibility)
 387            .add_request_handler(rename_channel)
 388            .add_request_handler(join_channel_buffer)
 389            .add_request_handler(leave_channel_buffer)
 390            .add_message_handler(update_channel_buffer)
 391            .add_request_handler(rejoin_channel_buffers)
 392            .add_request_handler(get_channel_members)
 393            .add_request_handler(respond_to_channel_invite)
 394            .add_request_handler(join_channel)
 395            .add_request_handler(join_channel_chat)
 396            .add_message_handler(leave_channel_chat)
 397            .add_request_handler(send_channel_message)
 398            .add_request_handler(remove_channel_message)
 399            .add_request_handler(update_channel_message)
 400            .add_request_handler(get_channel_messages)
 401            .add_request_handler(get_channel_messages_by_id)
 402            .add_request_handler(get_notifications)
 403            .add_request_handler(mark_notification_as_read)
 404            .add_request_handler(move_channel)
 405            .add_request_handler(follow)
 406            .add_message_handler(unfollow)
 407            .add_message_handler(update_followers)
 408            .add_request_handler(get_private_user_info)
 409            .add_request_handler(get_llm_api_token)
 410            .add_request_handler(accept_terms_of_service)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(get_supermaven_api_key)
 414            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 415            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 416            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 417            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 419            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 433            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 434            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 435            .add_message_handler(update_context);
 436
 437        Arc::new(server)
 438    }
 439
    /// Spawns a detached background task that cleans up resources left behind by stale
    /// servers, then returns immediately.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT`, the task asks the database for stale room and
    /// channel ids for this environment and server id. It then does the following:
    /// - Clears stale channel-buffer collaborators and sends the refreshed collaborator list to
    ///   the remaining connections.
    /// - Clears stale room participants and broadcasts the updated room (and channel).
    /// - Sends `CallCanceled` to users whose pending calls were canceled.
    /// - Pushes refreshed contact entries to the accepted contacts of affected users.
    /// - Deletes the LiveKit room once a room has no participants left.
    /// - Finally, deletes stale servers from the database.
    ///
    /// Errors inside the task are passed to `trace_err` (presumably logged) rather than
    /// propagated, so this function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminating pods time to exit before their resources are cleaned up.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if refreshing the room fails: no calls are
                        // canceled, no contacts are updated, and no LiveKit room is deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is only deleted once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to just these synchronous sends.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both database lookups are awaited before the pool lock is taken.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Only accepted contacts are sent the updated entry.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 585
 586    pub fn teardown(&self) {
 587        self.peer.teardown();
 588        self.connection_pool.lock().reset();
 589        let _ = self.teardown.send(true);
 590    }
 591
 592    #[cfg(test)]
 593    pub fn reset(&self, id: ServerId) {
 594        self.teardown();
 595        *self.id.lock() = id;
 596        self.peer.reset(id.0 as u32);
 597        let _ = self.teardown.send(false);
 598    }
 599
 600    #[cfg(test)]
 601    pub fn id(&self) -> ServerId {
 602        *self.id.lock()
 603    }
 604
 605    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 606    where
 607        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 608        Fut: 'static + Send + Future<Output = Result<()>>,
 609        M: EnvelopedMessage,
 610    {
 611        let prev_handler = self.handlers.insert(
 612            TypeId::of::<M>(),
 613            Box::new(move |envelope, session| {
 614                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 615                let received_at = envelope.received_at;
 616                tracing::info!("message received");
 617                let start_time = Instant::now();
 618                let future = (handler)(*envelope, session);
 619                async move {
 620                    let result = future.await;
 621                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 622                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 623                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 624                    let payload_type = M::NAME;
 625
 626                    match result {
 627                        Err(error) => {
 628                            tracing::error!(
 629                                ?error,
 630                                total_duration_ms,
 631                                processing_duration_ms,
 632                                queue_duration_ms,
 633                                payload_type,
 634                                "error handling message"
 635                            )
 636                        }
 637                        Ok(()) => tracing::info!(
 638                            total_duration_ms,
 639                            processing_duration_ms,
 640                            queue_duration_ms,
 641                            "finished handling message"
 642                        ),
 643                    }
 644                }
 645                .boxed()
 646            }),
 647        );
 648        if prev_handler.is_some() {
 649            panic!("registered a handler for the same message twice");
 650        }
 651        self
 652    }
 653
 654    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 655    where
 656        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 657        Fut: 'static + Send + Future<Output = Result<()>>,
 658        M: EnvelopedMessage,
 659    {
 660        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 661        self
 662    }
 663
 664    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 665    where
 666        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 667        Fut: Send + Future<Output = Result<()>>,
 668        M: RequestMessage,
 669    {
 670        let handler = Arc::new(handler);
 671        self.add_handler(move |envelope, session| {
 672            let receipt = envelope.receipt();
 673            let handler = handler.clone();
 674            async move {
 675                let peer = session.peer.clone();
 676                let responded = Arc::new(AtomicBool::default());
 677                let response = Response {
 678                    peer: peer.clone(),
 679                    responded: responded.clone(),
 680                    receipt,
 681                };
 682                match (handler)(envelope.payload, response, session).await {
 683                    Ok(()) => {
 684                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 685                            Ok(())
 686                        } else {
 687                            Err(anyhow!("handler did not send a response"))?
 688                        }
 689                    }
 690                    Err(error) => {
 691                        let proto_err = match &error {
 692                            Error::Internal(err) => err.to_proto(),
 693                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 694                        };
 695                        peer.respond_with_error(receipt, proto_err)?;
 696                        Err(error)
 697                    }
 698                }
 699            }
 700        })
 701    }
 702
 703    pub fn handle_connection(
 704        self: &Arc<Self>,
 705        connection: Connection,
 706        address: String,
 707        principal: Principal,
 708        zed_version: ZedVersion,
 709        geoip_country_code: Option<String>,
 710        system_id: Option<String>,
 711        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 712        executor: Executor,
 713    ) -> impl Future<Output = ()> + use<> {
 714        let this = self.clone();
 715        let span = info_span!("handle connection", %address,
 716            connection_id=field::Empty,
 717            user_id=field::Empty,
 718            login=field::Empty,
 719            impersonator=field::Empty,
 720            geoip_country_code=field::Empty
 721        );
 722        principal.update_span(&span);
 723        if let Some(country_code) = geoip_country_code.as_ref() {
 724            span.record("geoip_country_code", country_code);
 725        }
 726
 727        let mut teardown = self.teardown.subscribe();
 728        async move {
 729            if *teardown.borrow() {
 730                tracing::error!("server is tearing down");
 731                return
 732            }
 733            let (connection_id, handle_io, mut incoming_rx) = this
 734                .peer
 735                .add_connection(connection, {
 736                    let executor = executor.clone();
 737                    move |duration| executor.sleep(duration)
 738                });
 739            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 740
 741            tracing::info!("connection opened");
 742
 743            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 744            let http_client = match ReqwestClient::user_agent(&user_agent) {
 745                Ok(http_client) => Arc::new(http_client),
 746                Err(error) => {
 747                    tracing::error!(?error, "failed to create HTTP client");
 748                    return;
 749                }
 750            };
 751
 752            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 753                    supermaven_admin_api_key.to_string(),
 754                    http_client.clone(),
 755                )));
 756
 757            let session = Session {
 758                principal: principal.clone(),
 759                connection_id,
 760                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 761                peer: this.peer.clone(),
 762                connection_pool: this.connection_pool.clone(),
 763                app_state: this.app_state.clone(),
 764                geoip_country_code,
 765                system_id,
 766                _executor: executor.clone(),
 767                supermaven_client,
 768            };
 769
 770            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 771                tracing::error!(?error, "failed to send initial client update");
 772                return;
 773            }
 774
 775            let handle_io = handle_io.fuse();
 776            futures::pin_mut!(handle_io);
 777
 778            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 779            // This prevents deadlocks when e.g., client A performs a request to client B and
 780            // client B performs a request to client A. If both clients stop processing further
 781            // messages until their respective request completes, they won't have a chance to
 782            // respond to the other client's request and cause a deadlock.
 783            //
 784            // This arrangement ensures we will attempt to process earlier messages first, but fall
 785            // back to processing messages arrived later in the spirit of making progress.
 786            let mut foreground_message_handlers = FuturesUnordered::new();
 787            let concurrent_handlers = Arc::new(Semaphore::new(256));
 788            loop {
 789                let next_message = async {
 790                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 791                    let message = incoming_rx.next().await;
 792                    (permit, message)
 793                }.fuse();
 794                futures::pin_mut!(next_message);
 795                futures::select_biased! {
 796                    _ = teardown.changed().fuse() => return,
 797                    result = handle_io => {
 798                        if let Err(error) = result {
 799                            tracing::error!(?error, "error handling I/O");
 800                        }
 801                        break;
 802                    }
 803                    _ = foreground_message_handlers.next() => {}
 804                    next_message = next_message => {
 805                        let (permit, message) = next_message;
 806                        if let Some(message) = message {
 807                            let type_name = message.payload_type_name();
 808                            // note: we copy all the fields from the parent span so we can query them in the logs.
 809                            // (https://github.com/tokio-rs/tracing/issues/2670).
 810                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 811                                user_id=field::Empty,
 812                                login=field::Empty,
 813                                impersonator=field::Empty,
 814                            );
 815                            principal.update_span(&span);
 816                            let span_enter = span.enter();
 817                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 818                                let is_background = message.is_background();
 819                                let handle_message = (handler)(message, session.clone());
 820                                drop(span_enter);
 821
 822                                let handle_message = async move {
 823                                    handle_message.await;
 824                                    drop(permit);
 825                                }.instrument(span);
 826                                if is_background {
 827                                    executor.spawn_detached(handle_message);
 828                                } else {
 829                                    foreground_message_handlers.push(handle_message);
 830                                }
 831                            } else {
 832                                tracing::error!("no message handler");
 833                            }
 834                        } else {
 835                            tracing::info!("connection closed");
 836                            break;
 837                        }
 838                    }
 839                }
 840            }
 841
 842            drop(foreground_message_handlers);
 843            tracing::info!("signing out");
 844            if let Err(error) = connection_lost(session, teardown, executor).await {
 845                tracing::error!(?error, "error signing out");
 846            }
 847
 848        }.instrument(span)
 849    }
 850
 851    async fn send_initial_client_update(
 852        &self,
 853        connection_id: ConnectionId,
 854        principal: &Principal,
 855        zed_version: ZedVersion,
 856        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 857        session: &Session,
 858    ) -> Result<()> {
 859        self.peer.send(
 860            connection_id,
 861            proto::Hello {
 862                peer_id: Some(connection_id.into()),
 863            },
 864        )?;
 865        tracing::info!("sent hello message");
 866        if let Some(send_connection_id) = send_connection_id.take() {
 867            let _ = send_connection_id.send(connection_id);
 868        }
 869
 870        match principal {
 871            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 872                if !user.connected_once {
 873                    self.peer.send(connection_id, proto::ShowContacts {})?;
 874                    self.app_state
 875                        .db
 876                        .set_user_connected_once(user.id, true)
 877                        .await?;
 878                }
 879
 880                update_user_plan(user.id, session).await?;
 881
 882                let contacts = self.app_state.db.get_contacts(user.id).await?;
 883
 884                {
 885                    let mut pool = self.connection_pool.lock();
 886                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 887                    self.peer.send(
 888                        connection_id,
 889                        build_initial_contacts_update(contacts, &pool),
 890                    )?;
 891                }
 892
 893                if should_auto_subscribe_to_channels(zed_version) {
 894                    subscribe_user_to_channels(user.id, session).await?;
 895                }
 896
 897                if let Some(incoming_call) =
 898                    self.app_state.db.incoming_call_for_user(user.id).await?
 899                {
 900                    self.peer.send(connection_id, incoming_call)?;
 901                }
 902
 903                update_user_contacts(user.id, session).await?;
 904            }
 905        }
 906
 907        Ok(())
 908    }
 909
 910    pub async fn invite_code_redeemed(
 911        self: &Arc<Self>,
 912        inviter_id: UserId,
 913        invitee_id: UserId,
 914    ) -> Result<()> {
 915        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 916            if let Some(code) = &user.invite_code {
 917                let pool = self.connection_pool.lock();
 918                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 919                for connection_id in pool.user_connection_ids(inviter_id) {
 920                    self.peer.send(
 921                        connection_id,
 922                        proto::UpdateContacts {
 923                            contacts: vec![invitee_contact.clone()],
 924                            ..Default::default()
 925                        },
 926                    )?;
 927                    self.peer.send(
 928                        connection_id,
 929                        proto::UpdateInviteInfo {
 930                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 931                            count: user.invite_count as u32,
 932                        },
 933                    )?;
 934                }
 935            }
 936        }
 937        Ok(())
 938    }
 939
 940    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 941        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 942            if let Some(invite_code) = &user.invite_code {
 943                let pool = self.connection_pool.lock();
 944                for connection_id in pool.user_connection_ids(user_id) {
 945                    self.peer.send(
 946                        connection_id,
 947                        proto::UpdateInviteInfo {
 948                            url: format!(
 949                                "{}{}",
 950                                self.app_state.config.invite_link_prefix, invite_code
 951                            ),
 952                            count: user.invite_count as u32,
 953                        },
 954                    )?;
 955                }
 956            }
 957        }
 958        Ok(())
 959    }
 960
 961    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 962        let pool = self.connection_pool.lock();
 963        for connection_id in pool.user_connection_ids(user_id) {
 964            self.peer
 965                .send(connection_id, proto::RefreshLlmToken {})
 966                .trace_err();
 967        }
 968    }
 969
 970    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 971        ServerSnapshot {
 972            connection_pool: ConnectionPoolGuard {
 973                guard: self.connection_pool.lock(),
 974                _not_send: PhantomData,
 975            },
 976            peer: &self.peer,
 977        }
 978    }
 979}
 980
 981impl Deref for ConnectionPoolGuard<'_> {
 982    type Target = ConnectionPool;
 983
 984    fn deref(&self) -> &Self::Target {
 985        &self.guard
 986    }
 987}
 988
 989impl DerefMut for ConnectionPoolGuard<'_> {
 990    fn deref_mut(&mut self) -> &mut Self::Target {
 991        &mut self.guard
 992    }
 993}
 994
 995impl Drop for ConnectionPoolGuard<'_> {
 996    fn drop(&mut self) {
 997        #[cfg(test)]
 998        self.check_invariants();
 999    }
1000}
1001
1002fn broadcast<F>(
1003    sender_id: Option<ConnectionId>,
1004    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005    mut f: F,
1006) where
1007    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009    for receiver_id in receiver_ids {
1010        if Some(receiver_id) != sender_id {
1011            if let Err(error) = f(receiver_id) {
1012                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013            }
1014        }
1015    }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021    fn name() -> &'static HeaderName {
1022        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024    }
1025
1026    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027    where
1028        Self: Sized,
1029        I: Iterator<Item = &'i axum::http::HeaderValue>,
1030    {
1031        let version = values
1032            .next()
1033            .ok_or_else(axum::headers::Error::invalid)?
1034            .to_str()
1035            .map_err(|_| axum::headers::Error::invalid())?
1036            .parse()
1037            .map_err(|_| axum::headers::Error::invalid())?;
1038        Ok(Self(version))
1039    }
1040
1041    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042        values.extend([self.0.to_string().parse().unwrap()]);
1043    }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048    fn name() -> &'static HeaderName {
1049        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051    }
1052
1053    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054    where
1055        Self: Sized,
1056        I: Iterator<Item = &'i axum::http::HeaderValue>,
1057    {
1058        let version = values
1059            .next()
1060            .ok_or_else(axum::headers::Error::invalid)?
1061            .to_str()
1062            .map_err(|_| axum::headers::Error::invalid())?
1063            .parse()
1064            .map_err(|_| axum::headers::Error::invalid())?;
1065        Ok(Self(version))
1066    }
1067
1068    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069        values.extend([self.0.to_string().parse().unwrap()]);
1070    }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074    Router::new()
1075        .route("/rpc", get(handle_websocket_request))
1076        .layer(
1077            ServiceBuilder::new()
1078                .layer(Extension(server.app_state.clone()))
1079                .layer(middleware::from_fn(auth::validate_header)),
1080        )
1081        .route("/metrics", get(handle_metrics))
1082        .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089    Extension(server): Extension<Arc<Server>>,
1090    Extension(principal): Extension<Principal>,
1091    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093    ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095    if protocol_version != rpc::PROTOCOL_VERSION {
1096        return (
1097            StatusCode::UPGRADE_REQUIRED,
1098            "client must be upgraded".to_string(),
1099        )
1100            .into_response();
1101    }
1102
1103    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1104        return (
1105            StatusCode::UPGRADE_REQUIRED,
1106            "no version header found".to_string(),
1107        )
1108            .into_response();
1109    };
1110
1111    if !version.can_collaborate() {
1112        return (
1113            StatusCode::UPGRADE_REQUIRED,
1114            "client must be upgraded".to_string(),
1115        )
1116            .into_response();
1117    }
1118
1119    let socket_address = socket_address.to_string();
1120    ws.on_upgrade(move |socket| {
1121        let socket = socket
1122            .map_ok(to_tungstenite_message)
1123            .err_into()
1124            .with(|message| async move { to_axum_message(message) });
1125        let connection = Connection::new(Box::pin(socket));
1126        async move {
1127            server
1128                .handle_connection(
1129                    connection,
1130                    socket_address,
1131                    principal,
1132                    version,
1133                    country_code_header.map(|header| header.to_string()),
1134                    system_id_header.map(|header| header.to_string()),
1135                    None,
1136                    Executor::Production,
1137                )
1138                .await;
1139        }
1140    })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145    let connections_metric = CONNECTIONS_METRIC
1146        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148    let connections = server
1149        .connection_pool
1150        .lock()
1151        .connections()
1152        .filter(|connection| !connection.admin)
1153        .count();
1154    connections_metric.set(connections as _);
1155
1156    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158        register_int_gauge!(
1159            "shared_projects",
1160            "number of open projects with one or more guests"
1161        )
1162        .unwrap()
1163    });
1164
1165    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166    shared_projects_metric.set(shared_projects as _);
1167
1168    let encoder = prometheus::TextEncoder::new();
1169    let metric_families = prometheus::gather();
1170    let encoded_metrics = encoder
1171        .encode_to_string(&metric_families)
1172        .map_err(|err| anyhow!("{}", err))?;
1173    Ok(encoded_metrics)
1174}
1175
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool, returning early with an error if that
///   fails;
/// - records the lost connection in the database, only logging a failure there.
///
/// It then races a `RECONNECT_TIMEOUT` sleep against server teardown. If the
/// timeout wins, the session:
/// - leaves its room and channel buffers;
/// - calls `decline_call` when the user has no remaining online connection;
/// - refreshes the user's contacts.
///
/// If teardown wins, none of that happens.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not returned.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls arms in order, so the timeout arm wins if both are
    // ready. NOTE(review): the grace period presumably gives the client a chance to
    // reconnect before its room/buffer state is released — confirm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user
            // is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225    response.send(proto::Ack {})?;
1226    Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231    _request: proto::CreateRoom,
1232    response: Response<proto::CreateRoom>,
1233    session: Session,
1234) -> Result<()> {
1235    let livekit_room = nanoid::nanoid!(30);
1236
1237    let live_kit_connection_info = util::maybe!(async {
1238        let live_kit = session.app_state.livekit_client.as_ref();
1239        let live_kit = live_kit?;
1240        let user_id = session.user_id().to_string();
1241
1242        let token = live_kit
1243            .room_token(&livekit_room, &user_id.to_string())
1244            .trace_err()?;
1245
1246        Some(proto::LiveKitConnectionInfo {
1247            server_url: live_kit.url().into(),
1248            token,
1249            can_publish: true,
1250        })
1251    })
1252    .await;
1253
1254    let room = session
1255        .db()
1256        .await
1257        .create_room(session.user_id(), session.connection_id, &livekit_room)
1258        .await?;
1259
1260    response.send(proto::CreateRoomResponse {
1261        room: Some(room.clone()),
1262        live_kit_connection_info,
1263    })?;
1264
1265    update_user_contacts(session.user_id(), &session).await?;
1266    Ok(())
1267}
1268
1269/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1270async fn join_room(
1271    request: proto::JoinRoom,
1272    response: Response<proto::JoinRoom>,
1273    session: Session,
1274) -> Result<()> {
1275    let room_id = RoomId::from_proto(request.id);
1276
1277    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1278
1279    if let Some(channel_id) = channel_id {
1280        return join_channel_internal(channel_id, Box::new(response), session).await;
1281    }
1282
1283    let joined_room = {
1284        let room = session
1285            .db()
1286            .await
1287            .join_room(room_id, session.user_id(), session.connection_id)
1288            .await?;
1289        room_updated(&room.room, &session.peer);
1290        room.into_inner()
1291    };
1292
1293    for connection_id in session
1294        .connection_pool()
1295        .await
1296        .user_connection_ids(session.user_id())
1297    {
1298        session
1299            .peer
1300            .send(
1301                connection_id,
1302                proto::CallCanceled {
1303                    room_id: room_id.to_proto(),
1304                },
1305            )
1306            .trace_err();
1307    }
1308
1309    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1310    {
1311        live_kit
1312            .room_token(
1313                &joined_room.room.livekit_room,
1314                &session.user_id().to_string(),
1315            )
1316            .trace_err()
1317            .map(|token| proto::LiveKitConnectionInfo {
1318                server_url: live_kit.url().into(),
1319                token,
1320                can_publish: true,
1321            })
1322    } else {
1323        None
1324    };
1325
1326    response.send(proto::JoinRoomResponse {
1327        room: Some(joined_room.room),
1328        channel_id: None,
1329        live_kit_connection_info,
1330    })?;
1331
1332    update_user_contacts(session.user_id(), &session).await?;
1333    Ok(())
1334}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Replies to the requester with the room, the reshared projects (with their
///    collaborators) and the rejoined projects.
/// 2. Calls `room_updated`.
/// 3. For each reshared project, tells every collaborator that this user's peer id
///    changed from the project's `old_connection_id` to the current connection.
///    All collaborators except the rejoining connection are also forwarded the
///    project's worktrees.
/// 4. Streams rejoined-project state with `notify_rejoined_projects`.
/// 5. Once the database result has been consumed, updates the room's channel (if
///    any) and refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the value returned by `rejoin_room` is consumed (`into_inner`)
    // before the connection pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Send failures to individual collaborators are logged, not returned.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to everyone but the rejoining connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1427
1428fn notify_rejoined_projects(
1429    rejoined_projects: &mut Vec<RejoinedProject>,
1430    session: &Session,
1431) -> Result<()> {
1432    for project in rejoined_projects.iter() {
1433        for collaborator in &project.collaborators {
1434            session
1435                .peer
1436                .send(
1437                    collaborator.connection_id,
1438                    proto::UpdateProjectCollaborator {
1439                        project_id: project.id.to_proto(),
1440                        old_peer_id: Some(project.old_connection_id.into()),
1441                        new_peer_id: Some(session.connection_id.into()),
1442                    },
1443                )
1444                .trace_err();
1445        }
1446    }
1447
1448    for project in rejoined_projects {
1449        for worktree in mem::take(&mut project.worktrees) {
1450            // Stream this worktree's entries.
1451            let message = proto::UpdateWorktree {
1452                project_id: project.id.to_proto(),
1453                worktree_id: worktree.id,
1454                abs_path: worktree.abs_path.clone(),
1455                root_name: worktree.root_name,
1456                updated_entries: worktree.updated_entries,
1457                removed_entries: worktree.removed_entries,
1458                scan_id: worktree.scan_id,
1459                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1460                updated_repositories: worktree.updated_repositories,
1461                removed_repositories: worktree.removed_repositories,
1462            };
1463            for update in proto::split_worktree_update(message) {
1464                session.peer.send(session.connection_id, update)?;
1465            }
1466
1467            // Stream this worktree's diagnostics.
1468            for summary in worktree.diagnostic_summaries {
1469                session.peer.send(
1470                    session.connection_id,
1471                    proto::UpdateDiagnosticSummary {
1472                        project_id: project.id.to_proto(),
1473                        worktree_id: worktree.id,
1474                        summary: Some(summary),
1475                    },
1476                )?;
1477            }
1478
1479            for settings_file in worktree.settings_files {
1480                session.peer.send(
1481                    session.connection_id,
1482                    proto::UpdateWorktreeSettings {
1483                        project_id: project.id.to_proto(),
1484                        worktree_id: worktree.id,
1485                        path: settings_file.path,
1486                        content: Some(settings_file.content),
1487                        kind: Some(settings_file.kind.to_proto().into()),
1488                    },
1489                )?;
1490            }
1491        }
1492
1493        for repository in mem::take(&mut project.updated_repositories) {
1494            for update in split_repository_update(repository) {
1495                session.peer.send(session.connection_id, update)?;
1496            }
1497        }
1498
1499        for id in mem::take(&mut project.removed_repositories) {
1500            session.peer.send(
1501                session.connection_id,
1502                proto::RemoveRepository {
1503                    project_id: project.id.to_proto(),
1504                    id,
1505                },
1506            )?;
1507        }
1508    }
1509
1510    Ok(())
1511}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515    _: proto::LeaveRoom,
1516    response: Response<proto::LeaveRoom>,
1517    session: Session,
1518) -> Result<()> {
1519    leave_room_for_session(&session, session.connection_id).await?;
1520    response.send(proto::Ack {})?;
1521    Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526    request: proto::SetRoomParticipantRole,
1527    response: Response<proto::SetRoomParticipantRole>,
1528    session: Session,
1529) -> Result<()> {
1530    let user_id = UserId::from_proto(request.user_id);
1531    let role = ChannelRole::from(request.role());
1532
1533    let (livekit_room, can_publish) = {
1534        let room = session
1535            .db()
1536            .await
1537            .set_room_participant_role(
1538                session.user_id(),
1539                RoomId::from_proto(request.room_id),
1540                user_id,
1541                role,
1542            )
1543            .await?;
1544
1545        let livekit_room = room.livekit_room.clone();
1546        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547        room_updated(&room, &session.peer);
1548        (livekit_room, can_publish)
1549    };
1550
1551    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552        live_kit
1553            .update_participant(
1554                livekit_room.clone(),
1555                request.user_id.to_string(),
1556                livekit_api::proto::ParticipantPermission {
1557                    can_subscribe: true,
1558                    can_publish,
1559                    can_publish_data: can_publish,
1560                    hidden: false,
1561                    recorder: false,
1562                },
1563            )
1564            .await
1565            .trace_err();
1566    }
1567
1568    response.send(proto::Ack {})?;
1569    Ok(())
1570}
1571
1572/// Call someone else into the current room
1573async fn call(
1574    request: proto::Call,
1575    response: Response<proto::Call>,
1576    session: Session,
1577) -> Result<()> {
1578    let room_id = RoomId::from_proto(request.room_id);
1579    let calling_user_id = session.user_id();
1580    let calling_connection_id = session.connection_id;
1581    let called_user_id = UserId::from_proto(request.called_user_id);
1582    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1583    if !session
1584        .db()
1585        .await
1586        .has_contact(calling_user_id, called_user_id)
1587        .await?
1588    {
1589        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1590    }
1591
1592    let incoming_call = {
1593        let (room, incoming_call) = &mut *session
1594            .db()
1595            .await
1596            .call(
1597                room_id,
1598                calling_user_id,
1599                calling_connection_id,
1600                called_user_id,
1601                initial_project_id,
1602            )
1603            .await?;
1604        room_updated(room, &session.peer);
1605        mem::take(incoming_call)
1606    };
1607    update_user_contacts(called_user_id, &session).await?;
1608
1609    let mut calls = session
1610        .connection_pool()
1611        .await
1612        .user_connection_ids(called_user_id)
1613        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1614        .collect::<FuturesUnordered<_>>();
1615
1616    while let Some(call_response) = calls.next().await {
1617        match call_response.as_ref() {
1618            Ok(_) => {
1619                response.send(proto::Ack {})?;
1620                return Ok(());
1621            }
1622            Err(_) => {
1623                call_response.trace_err();
1624            }
1625        }
1626    }
1627
1628    {
1629        let room = session
1630            .db()
1631            .await
1632            .call_failed(room_id, called_user_id)
1633            .await?;
1634        room_updated(&room, &session.peer);
1635    }
1636    update_user_contacts(called_user_id, &session).await?;
1637
1638    Err(anyhow!("failed to ring user"))?
1639}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643    request: proto::CancelCall,
1644    response: Response<proto::CancelCall>,
1645    session: Session,
1646) -> Result<()> {
1647    let called_user_id = UserId::from_proto(request.called_user_id);
1648    let room_id = RoomId::from_proto(request.room_id);
1649    {
1650        let room = session
1651            .db()
1652            .await
1653            .cancel_call(room_id, session.connection_id, called_user_id)
1654            .await?;
1655        room_updated(&room, &session.peer);
1656    }
1657
1658    for connection_id in session
1659        .connection_pool()
1660        .await
1661        .user_connection_ids(called_user_id)
1662    {
1663        session
1664            .peer
1665            .send(
1666                connection_id,
1667                proto::CallCanceled {
1668                    room_id: room_id.to_proto(),
1669                },
1670            )
1671            .trace_err();
1672    }
1673    response.send(proto::Ack {})?;
1674
1675    update_user_contacts(called_user_id, &session).await?;
1676    Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681    let room_id = RoomId::from_proto(message.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .decline_call(Some(room_id), session.user_id())
1687            .await?
1688            .ok_or_else(|| anyhow!("failed to decline call"))?;
1689        room_updated(&room, &session.peer);
1690    }
1691
1692    for connection_id in session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(session.user_id())
1696    {
1697        session
1698            .peer
1699            .send(
1700                connection_id,
1701                proto::CallCanceled {
1702                    room_id: room_id.to_proto(),
1703                },
1704            )
1705            .trace_err();
1706    }
1707    update_user_contacts(session.user_id(), &session).await?;
1708    Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713    request: proto::UpdateParticipantLocation,
1714    response: Response<proto::UpdateParticipantLocation>,
1715    session: Session,
1716) -> Result<()> {
1717    let room_id = RoomId::from_proto(request.room_id);
1718    let location = request
1719        .location
1720        .ok_or_else(|| anyhow!("invalid location"))?;
1721
1722    let db = session.db().await;
1723    let room = db
1724        .update_room_participant_location(room_id, session.connection_id, location)
1725        .await?;
1726
1727    room_updated(&room, &session.peer);
1728    response.send(proto::Ack {})?;
1729    Ok(())
1730}
1731
1732/// Share a project into the room.
1733async fn share_project(
1734    request: proto::ShareProject,
1735    response: Response<proto::ShareProject>,
1736    session: Session,
1737) -> Result<()> {
1738    let (project_id, room) = &*session
1739        .db()
1740        .await
1741        .share_project(
1742            RoomId::from_proto(request.room_id),
1743            session.connection_id,
1744            &request.worktrees,
1745            request.is_ssh_project,
1746        )
1747        .await?;
1748    response.send(proto::ShareProjectResponse {
1749        project_id: project_id.to_proto(),
1750    })?;
1751    room_updated(room, &session.peer);
1752
1753    Ok(())
1754}
1755
1756/// Unshare a project from the room.
1757async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1758    let project_id = ProjectId::from_proto(message.project_id);
1759    unshare_project_internal(project_id, session.connection_id, &session).await
1760}
1761
1762async fn unshare_project_internal(
1763    project_id: ProjectId,
1764    connection_id: ConnectionId,
1765    session: &Session,
1766) -> Result<()> {
1767    let delete = {
1768        let room_guard = session
1769            .db()
1770            .await
1771            .unshare_project(project_id, connection_id)
1772            .await?;
1773
1774        let (delete, room, guest_connection_ids) = &*room_guard;
1775
1776        let message = proto::UnshareProject {
1777            project_id: project_id.to_proto(),
1778        };
1779
1780        broadcast(
1781            Some(connection_id),
1782            guest_connection_ids.iter().copied(),
1783            |conn_id| session.peer.send(conn_id, message.clone()),
1784        );
1785        if let Some(room) = room {
1786            room_updated(room, &session.peer);
1787        }
1788
1789        *delete
1790    };
1791
1792    if delete {
1793        let db = session.db().await;
1794        db.delete_project(project_id).await?;
1795    }
1796
1797    Ok(())
1798}
1799
1800/// Join someone elses shared project.
1801async fn join_project(
1802    request: proto::JoinProject,
1803    response: Response<proto::JoinProject>,
1804    session: Session,
1805) -> Result<()> {
1806    let project_id = ProjectId::from_proto(request.project_id);
1807
1808    tracing::info!(%project_id, "join project");
1809
1810    let db = session.db().await;
1811    let (project, replica_id) = &mut *db
1812        .join_project(project_id, session.connection_id, session.user_id())
1813        .await?;
1814    drop(db);
1815    tracing::info!(%project_id, "join remote project");
1816    join_project_internal(response, session, project, replica_id)
1817}
1818
/// Reply channel used by `join_project_internal` to answer a join with a
/// `proto::JoinProjectResponse` (presumably abstracted so other join paths can reuse it;
/// only the `proto::JoinProject` impl is visible here).
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `proto::JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // NOTE(review): relies on `Response` having an inherent `send`, which path
        // resolution prefers over this trait method; otherwise this would recurse.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1827
1828fn join_project_internal(
1829    response: impl JoinProjectInternalResponse,
1830    session: Session,
1831    project: &mut Project,
1832    replica_id: &ReplicaId,
1833) -> Result<()> {
1834    let collaborators = project
1835        .collaborators
1836        .iter()
1837        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1838        .map(|collaborator| collaborator.to_proto())
1839        .collect::<Vec<_>>();
1840    let project_id = project.id;
1841    let guest_user_id = session.user_id();
1842
1843    let worktrees = project
1844        .worktrees
1845        .iter()
1846        .map(|(id, worktree)| proto::WorktreeMetadata {
1847            id: *id,
1848            root_name: worktree.root_name.clone(),
1849            visible: worktree.visible,
1850            abs_path: worktree.abs_path.clone(),
1851        })
1852        .collect::<Vec<_>>();
1853
1854    let add_project_collaborator = proto::AddProjectCollaborator {
1855        project_id: project_id.to_proto(),
1856        collaborator: Some(proto::Collaborator {
1857            peer_id: Some(session.connection_id.into()),
1858            replica_id: replica_id.0 as u32,
1859            user_id: guest_user_id.to_proto(),
1860            is_host: false,
1861        }),
1862    };
1863
1864    for collaborator in &collaborators {
1865        session
1866            .peer
1867            .send(
1868                collaborator.peer_id.unwrap().into(),
1869                add_project_collaborator.clone(),
1870            )
1871            .trace_err();
1872    }
1873
1874    // First, we send the metadata associated with each worktree.
1875    response.send(proto::JoinProjectResponse {
1876        project_id: project.id.0 as u64,
1877        worktrees: worktrees.clone(),
1878        replica_id: replica_id.0 as u32,
1879        collaborators: collaborators.clone(),
1880        language_servers: project.language_servers.clone(),
1881        role: project.role.into(),
1882    })?;
1883
1884    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885        // Stream this worktree's entries.
1886        let message = proto::UpdateWorktree {
1887            project_id: project_id.to_proto(),
1888            worktree_id,
1889            abs_path: worktree.abs_path.clone(),
1890            root_name: worktree.root_name,
1891            updated_entries: worktree.entries,
1892            removed_entries: Default::default(),
1893            scan_id: worktree.scan_id,
1894            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1895            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1896            removed_repositories: Default::default(),
1897        };
1898        for update in proto::split_worktree_update(message) {
1899            session.peer.send(session.connection_id, update.clone())?;
1900        }
1901
1902        // Stream this worktree's diagnostics.
1903        for summary in worktree.diagnostic_summaries {
1904            session.peer.send(
1905                session.connection_id,
1906                proto::UpdateDiagnosticSummary {
1907                    project_id: project_id.to_proto(),
1908                    worktree_id: worktree.id,
1909                    summary: Some(summary),
1910                },
1911            )?;
1912        }
1913
1914        for settings_file in worktree.settings_files {
1915            session.peer.send(
1916                session.connection_id,
1917                proto::UpdateWorktreeSettings {
1918                    project_id: project_id.to_proto(),
1919                    worktree_id: worktree.id,
1920                    path: settings_file.path,
1921                    content: Some(settings_file.content),
1922                    kind: Some(settings_file.kind.to_proto() as i32),
1923                },
1924            )?;
1925        }
1926    }
1927
1928    for repository in mem::take(&mut project.repositories) {
1929        for update in split_repository_update(repository) {
1930            session.peer.send(session.connection_id, update)?;
1931        }
1932    }
1933
1934    for language_server in &project.language_servers {
1935        session.peer.send(
1936            session.connection_id,
1937            proto::UpdateLanguageServer {
1938                project_id: project_id.to_proto(),
1939                language_server_id: language_server.id,
1940                variant: Some(
1941                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1942                        proto::LspDiskBasedDiagnosticsUpdated {},
1943                    ),
1944                ),
1945            },
1946        )?;
1947    }
1948
1949    Ok(())
1950}
1951
1952/// Leave someone elses shared project.
1953async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1954    let sender_id = session.connection_id;
1955    let project_id = ProjectId::from_proto(request.project_id);
1956    let db = session.db().await;
1957
1958    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1959    tracing::info!(
1960        %project_id,
1961        "leave project"
1962    );
1963
1964    project_left(project, &session);
1965    if let Some(room) = room {
1966        room_updated(room, &session.peer);
1967    }
1968
1969    Ok(())
1970}
1971
1972/// Updates other participants with changes to the project
1973async fn update_project(
1974    request: proto::UpdateProject,
1975    response: Response<proto::UpdateProject>,
1976    session: Session,
1977) -> Result<()> {
1978    let project_id = ProjectId::from_proto(request.project_id);
1979    let (room, guest_connection_ids) = &*session
1980        .db()
1981        .await
1982        .update_project(project_id, session.connection_id, &request.worktrees)
1983        .await?;
1984    broadcast(
1985        Some(session.connection_id),
1986        guest_connection_ids.iter().copied(),
1987        |connection_id| {
1988            session
1989                .peer
1990                .forward_send(session.connection_id, connection_id, request.clone())
1991        },
1992    );
1993    if let Some(room) = room {
1994        room_updated(room, &session.peer);
1995    }
1996    response.send(proto::Ack {})?;
1997
1998    Ok(())
1999}
2000
2001/// Updates other participants with changes to the worktree
2002async fn update_worktree(
2003    request: proto::UpdateWorktree,
2004    response: Response<proto::UpdateWorktree>,
2005    session: Session,
2006) -> Result<()> {
2007    let guest_connection_ids = session
2008        .db()
2009        .await
2010        .update_worktree(&request, session.connection_id)
2011        .await?;
2012
2013    broadcast(
2014        Some(session.connection_id),
2015        guest_connection_ids.iter().copied(),
2016        |connection_id| {
2017            session
2018                .peer
2019                .forward_send(session.connection_id, connection_id, request.clone())
2020        },
2021    );
2022    response.send(proto::Ack {})?;
2023    Ok(())
2024}
2025
2026async fn update_repository(
2027    request: proto::UpdateRepository,
2028    response: Response<proto::UpdateRepository>,
2029    session: Session,
2030) -> Result<()> {
2031    let guest_connection_ids = session
2032        .db()
2033        .await
2034        .update_repository(&request, session.connection_id)
2035        .await?;
2036
2037    broadcast(
2038        Some(session.connection_id),
2039        guest_connection_ids.iter().copied(),
2040        |connection_id| {
2041            session
2042                .peer
2043                .forward_send(session.connection_id, connection_id, request.clone())
2044        },
2045    );
2046    response.send(proto::Ack {})?;
2047    Ok(())
2048}
2049
2050async fn remove_repository(
2051    request: proto::RemoveRepository,
2052    response: Response<proto::RemoveRepository>,
2053    session: Session,
2054) -> Result<()> {
2055    let guest_connection_ids = session
2056        .db()
2057        .await
2058        .remove_repository(&request, session.connection_id)
2059        .await?;
2060
2061    broadcast(
2062        Some(session.connection_id),
2063        guest_connection_ids.iter().copied(),
2064        |connection_id| {
2065            session
2066                .peer
2067                .forward_send(session.connection_id, connection_id, request.clone())
2068        },
2069    );
2070    response.send(proto::Ack {})?;
2071    Ok(())
2072}
2073
2074/// Updates other participants with changes to the diagnostics
2075async fn update_diagnostic_summary(
2076    message: proto::UpdateDiagnosticSummary,
2077    session: Session,
2078) -> Result<()> {
2079    let guest_connection_ids = session
2080        .db()
2081        .await
2082        .update_diagnostic_summary(&message, session.connection_id)
2083        .await?;
2084
2085    broadcast(
2086        Some(session.connection_id),
2087        guest_connection_ids.iter().copied(),
2088        |connection_id| {
2089            session
2090                .peer
2091                .forward_send(session.connection_id, connection_id, message.clone())
2092        },
2093    );
2094
2095    Ok(())
2096}
2097
2098/// Updates other participants with changes to the worktree settings
2099async fn update_worktree_settings(
2100    message: proto::UpdateWorktreeSettings,
2101    session: Session,
2102) -> Result<()> {
2103    let guest_connection_ids = session
2104        .db()
2105        .await
2106        .update_worktree_settings(&message, session.connection_id)
2107        .await?;
2108
2109    broadcast(
2110        Some(session.connection_id),
2111        guest_connection_ids.iter().copied(),
2112        |connection_id| {
2113            session
2114                .peer
2115                .forward_send(session.connection_id, connection_id, message.clone())
2116        },
2117    );
2118
2119    Ok(())
2120}
2121
2122/// Notify other participants that a language server has started.
2123async fn start_language_server(
2124    request: proto::StartLanguageServer,
2125    session: Session,
2126) -> Result<()> {
2127    let guest_connection_ids = session
2128        .db()
2129        .await
2130        .start_language_server(&request, session.connection_id)
2131        .await?;
2132
2133    broadcast(
2134        Some(session.connection_id),
2135        guest_connection_ids.iter().copied(),
2136        |connection_id| {
2137            session
2138                .peer
2139                .forward_send(session.connection_id, connection_id, request.clone())
2140        },
2141    );
2142    Ok(())
2143}
2144
2145/// Notify other participants that a language server has changed.
2146async fn update_language_server(
2147    request: proto::UpdateLanguageServer,
2148    session: Session,
2149) -> Result<()> {
2150    let project_id = ProjectId::from_proto(request.project_id);
2151    let project_connection_ids = session
2152        .db()
2153        .await
2154        .project_connection_ids(project_id, session.connection_id, true)
2155        .await?;
2156    broadcast(
2157        Some(session.connection_id),
2158        project_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, request.clone())
2163        },
2164    );
2165    Ok(())
2166}
2167
2168/// forward a project request to the host. These requests should be read only
2169/// as guests are allowed to send them.
2170async fn forward_read_only_project_request<T>(
2171    request: T,
2172    response: Response<T>,
2173    session: Session,
2174) -> Result<()>
2175where
2176    T: EntityMessage + RequestMessage,
2177{
2178    let project_id = ProjectId::from_proto(request.remote_entity_id());
2179    let host_connection_id = session
2180        .db()
2181        .await
2182        .host_for_read_only_project_request(project_id, session.connection_id)
2183        .await?;
2184    let payload = session
2185        .peer
2186        .forward_request(session.connection_id, host_connection_id, request)
2187        .await?;
2188    response.send(payload)?;
2189    Ok(())
2190}
2191
2192async fn forward_find_search_candidates_request(
2193    request: proto::FindSearchCandidates,
2194    response: Response<proto::FindSearchCandidates>,
2195    session: Session,
2196) -> Result<()> {
2197    let project_id = ProjectId::from_proto(request.remote_entity_id());
2198    let host_connection_id = session
2199        .db()
2200        .await
2201        .host_for_read_only_project_request(project_id, session.connection_id)
2202        .await?;
2203    let payload = session
2204        .peer
2205        .forward_request(session.connection_id, host_connection_id, request)
2206        .await?;
2207    response.send(payload)?;
2208    Ok(())
2209}
2210
2211/// forward a project request to the host. These requests are disallowed
2212/// for guests.
2213async fn forward_mutating_project_request<T>(
2214    request: T,
2215    response: Response<T>,
2216    session: Session,
2217) -> Result<()>
2218where
2219    T: EntityMessage + RequestMessage,
2220{
2221    let project_id = ProjectId::from_proto(request.remote_entity_id());
2222
2223    let host_connection_id = session
2224        .db()
2225        .await
2226        .host_for_mutating_project_request(project_id, session.connection_id)
2227        .await?;
2228    let payload = session
2229        .peer
2230        .forward_request(session.connection_id, host_connection_id, request)
2231        .await?;
2232    response.send(payload)?;
2233    Ok(())
2234}
2235
2236/// Notify other participants that a new buffer has been created
2237async fn create_buffer_for_peer(
2238    request: proto::CreateBufferForPeer,
2239    session: Session,
2240) -> Result<()> {
2241    session
2242        .db()
2243        .await
2244        .check_user_is_project_host(
2245            ProjectId::from_proto(request.project_id),
2246            session.connection_id,
2247        )
2248        .await?;
2249    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2250    session
2251        .peer
2252        .forward_send(session.connection_id, peer_id.into(), request)?;
2253    Ok(())
2254}
2255
2256/// Notify other participants that a buffer has been updated. This is
2257/// allowed for guests as long as the update is limited to selections.
2258async fn update_buffer(
2259    request: proto::UpdateBuffer,
2260    response: Response<proto::UpdateBuffer>,
2261    session: Session,
2262) -> Result<()> {
2263    let project_id = ProjectId::from_proto(request.project_id);
2264    let mut capability = Capability::ReadOnly;
2265
2266    for op in request.operations.iter() {
2267        match op.variant {
2268            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2269            Some(_) => capability = Capability::ReadWrite,
2270        }
2271    }
2272
2273    let host = {
2274        let guard = session
2275            .db()
2276            .await
2277            .connections_for_buffer_update(project_id, session.connection_id, capability)
2278            .await?;
2279
2280        let (host, guests) = &*guard;
2281
2282        broadcast(
2283            Some(session.connection_id),
2284            guests.clone(),
2285            |connection_id| {
2286                session
2287                    .peer
2288                    .forward_send(session.connection_id, connection_id, request.clone())
2289            },
2290        );
2291
2292        *host
2293    };
2294
2295    if host != session.connection_id {
2296        session
2297            .peer
2298            .forward_request(session.connection_id, host, request.clone())
2299            .await?;
2300    }
2301
2302    response.send(proto::Ack {})?;
2303    Ok(())
2304}
2305
2306async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2307    let project_id = ProjectId::from_proto(message.project_id);
2308
2309    let operation = message.operation.as_ref().context("invalid operation")?;
2310    let capability = match operation.variant.as_ref() {
2311        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2312            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2313                match buffer_op.variant {
2314                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2315                        Capability::ReadOnly
2316                    }
2317                    _ => Capability::ReadWrite,
2318                }
2319            } else {
2320                Capability::ReadWrite
2321            }
2322        }
2323        Some(_) => Capability::ReadWrite,
2324        None => Capability::ReadOnly,
2325    };
2326
2327    let guard = session
2328        .db()
2329        .await
2330        .connections_for_buffer_update(project_id, session.connection_id, capability)
2331        .await?;
2332
2333    let (host, guests) = &*guard;
2334
2335    broadcast(
2336        Some(session.connection_id),
2337        guests.iter().chain([host]).copied(),
2338        |connection_id| {
2339            session
2340                .peer
2341                .forward_send(session.connection_id, connection_id, message.clone())
2342        },
2343    );
2344
2345    Ok(())
2346}
2347
2348/// Notify other participants that a project has been updated.
2349async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2350    request: T,
2351    session: Session,
2352) -> Result<()> {
2353    let project_id = ProjectId::from_proto(request.remote_entity_id());
2354    let project_connection_ids = session
2355        .db()
2356        .await
2357        .project_connection_ids(project_id, session.connection_id, false)
2358        .await?;
2359
2360    broadcast(
2361        Some(session.connection_id),
2362        project_connection_ids.iter().copied(),
2363        |connection_id| {
2364            session
2365                .peer
2366                .forward_send(session.connection_id, connection_id, request.clone())
2367        },
2368    );
2369    Ok(())
2370}
2371
2372/// Start following another user in a call.
2373async fn follow(
2374    request: proto::Follow,
2375    response: Response<proto::Follow>,
2376    session: Session,
2377) -> Result<()> {
2378    let room_id = RoomId::from_proto(request.room_id);
2379    let project_id = request.project_id.map(ProjectId::from_proto);
2380    let leader_id = request
2381        .leader_id
2382        .ok_or_else(|| anyhow!("invalid leader id"))?
2383        .into();
2384    let follower_id = session.connection_id;
2385
2386    session
2387        .db()
2388        .await
2389        .check_room_participants(room_id, leader_id, session.connection_id)
2390        .await?;
2391
2392    let response_payload = session
2393        .peer
2394        .forward_request(session.connection_id, leader_id, request)
2395        .await?;
2396    response.send(response_payload)?;
2397
2398    if let Some(project_id) = project_id {
2399        let room = session
2400            .db()
2401            .await
2402            .follow(room_id, project_id, leader_id, follower_id)
2403            .await?;
2404        room_updated(&room, &session.peer);
2405    }
2406
2407    Ok(())
2408}
2409
2410/// Stop following another user in a call.
2411async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2412    let room_id = RoomId::from_proto(request.room_id);
2413    let project_id = request.project_id.map(ProjectId::from_proto);
2414    let leader_id = request
2415        .leader_id
2416        .ok_or_else(|| anyhow!("invalid leader id"))?
2417        .into();
2418    let follower_id = session.connection_id;
2419
2420    session
2421        .db()
2422        .await
2423        .check_room_participants(room_id, leader_id, session.connection_id)
2424        .await?;
2425
2426    session
2427        .peer
2428        .forward_send(session.connection_id, leader_id, request)?;
2429
2430    if let Some(project_id) = project_id {
2431        let room = session
2432            .db()
2433            .await
2434            .unfollow(room_id, project_id, leader_id, follower_id)
2435            .await?;
2436        room_updated(&room, &session.peer);
2437    }
2438
2439    Ok(())
2440}
2441
2442/// Notify everyone following you of your current location.
2443async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2444    let room_id = RoomId::from_proto(request.room_id);
2445    let database = session.db.lock().await;
2446
2447    let connection_ids = if let Some(project_id) = request.project_id {
2448        let project_id = ProjectId::from_proto(project_id);
2449        database
2450            .project_connection_ids(project_id, session.connection_id, true)
2451            .await?
2452    } else {
2453        database
2454            .room_connection_ids(room_id, session.connection_id)
2455            .await?
2456    };
2457
2458    // For now, don't send view update messages back to that view's current leader.
2459    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2460        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2461        _ => None,
2462    });
2463
2464    for connection_id in connection_ids.iter().cloned() {
2465        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2466            session
2467                .peer
2468                .forward_send(session.connection_id, connection_id, request.clone())?;
2469        }
2470    }
2471    Ok(())
2472}
2473
2474/// Get public data about users.
2475async fn get_users(
2476    request: proto::GetUsers,
2477    response: Response<proto::GetUsers>,
2478    session: Session,
2479) -> Result<()> {
2480    let user_ids = request
2481        .user_ids
2482        .into_iter()
2483        .map(UserId::from_proto)
2484        .collect();
2485    let users = session
2486        .db()
2487        .await
2488        .get_users_by_ids(user_ids)
2489        .await?
2490        .into_iter()
2491        .map(|user| proto::User {
2492            id: user.id.to_proto(),
2493            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2494            github_login: user.github_login,
2495            email: user.email_address,
2496            name: user.name,
2497        })
2498        .collect();
2499    response.send(proto::UsersResponse { users })?;
2500    Ok(())
2501}
2502
2503/// Search for users (to invite) buy Github login
2504async fn fuzzy_search_users(
2505    request: proto::FuzzySearchUsers,
2506    response: Response<proto::FuzzySearchUsers>,
2507    session: Session,
2508) -> Result<()> {
2509    let query = request.query;
2510    let users = match query.len() {
2511        0 => vec![],
2512        1 | 2 => session
2513            .db()
2514            .await
2515            .get_user_by_github_login(&query)
2516            .await?
2517            .into_iter()
2518            .collect(),
2519        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2520    };
2521    let users = users
2522        .into_iter()
2523        .filter(|user| user.id != session.user_id())
2524        .map(|user| proto::User {
2525            id: user.id.to_proto(),
2526            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2527            github_login: user.github_login,
2528            name: user.name,
2529            email: user.email_address,
2530        })
2531        .collect();
2532    response.send(proto::UsersResponse { users })?;
2533    Ok(())
2534}
2535
2536/// Send a contact request to another user.
2537async fn request_contact(
2538    request: proto::RequestContact,
2539    response: Response<proto::RequestContact>,
2540    session: Session,
2541) -> Result<()> {
2542    let requester_id = session.user_id();
2543    let responder_id = UserId::from_proto(request.responder_id);
2544    if requester_id == responder_id {
2545        return Err(anyhow!("cannot add yourself as a contact"))?;
2546    }
2547
2548    let notifications = session
2549        .db()
2550        .await
2551        .send_contact_request(requester_id, responder_id)
2552        .await?;
2553
2554    // Update outgoing contact requests of requester
2555    let mut update = proto::UpdateContacts::default();
2556    update.outgoing_requests.push(responder_id.to_proto());
2557    for connection_id in session
2558        .connection_pool()
2559        .await
2560        .user_connection_ids(requester_id)
2561    {
2562        session.peer.send(connection_id, update.clone())?;
2563    }
2564
2565    // Update incoming contact requests of responder
2566    let mut update = proto::UpdateContacts::default();
2567    update
2568        .incoming_requests
2569        .push(proto::IncomingContactRequest {
2570            requester_id: requester_id.to_proto(),
2571        });
2572    let connection_pool = session.connection_pool().await;
2573    for connection_id in connection_pool.user_connection_ids(responder_id) {
2574        session.peer.send(connection_id, update.clone())?;
2575    }
2576
2577    send_notifications(&connection_pool, &session.peer, notifications);
2578
2579    response.send(proto::Ack {})?;
2580    Ok(())
2581}
2582
2583/// Accept or decline a contact request
2584async fn respond_to_contact_request(
2585    request: proto::RespondToContactRequest,
2586    response: Response<proto::RespondToContactRequest>,
2587    session: Session,
2588) -> Result<()> {
2589    let responder_id = session.user_id();
2590    let requester_id = UserId::from_proto(request.requester_id);
2591    let db = session.db().await;
2592    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2593        db.dismiss_contact_notification(responder_id, requester_id)
2594            .await?;
2595    } else {
2596        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2597
2598        let notifications = db
2599            .respond_to_contact_request(responder_id, requester_id, accept)
2600            .await?;
2601        let requester_busy = db.is_user_busy(requester_id).await?;
2602        let responder_busy = db.is_user_busy(responder_id).await?;
2603
2604        let pool = session.connection_pool().await;
2605        // Update responder with new contact
2606        let mut update = proto::UpdateContacts::default();
2607        if accept {
2608            update
2609                .contacts
2610                .push(contact_for_user(requester_id, requester_busy, &pool));
2611        }
2612        update
2613            .remove_incoming_requests
2614            .push(requester_id.to_proto());
2615        for connection_id in pool.user_connection_ids(responder_id) {
2616            session.peer.send(connection_id, update.clone())?;
2617        }
2618
2619        // Update requester with new contact
2620        let mut update = proto::UpdateContacts::default();
2621        if accept {
2622            update
2623                .contacts
2624                .push(contact_for_user(responder_id, responder_busy, &pool));
2625        }
2626        update
2627            .remove_outgoing_requests
2628            .push(responder_id.to_proto());
2629
2630        for connection_id in pool.user_connection_ids(requester_id) {
2631            session.peer.send(connection_id, update.clone())?;
2632        }
2633
2634        send_notifications(&pool, &session.peer, notifications);
2635    }
2636
2637    response.send(proto::Ack {})?;
2638    Ok(())
2639}
2640
2641/// Remove a contact.
2642async fn remove_contact(
2643    request: proto::RemoveContact,
2644    response: Response<proto::RemoveContact>,
2645    session: Session,
2646) -> Result<()> {
2647    let requester_id = session.user_id();
2648    let responder_id = UserId::from_proto(request.user_id);
2649    let db = session.db().await;
2650    let (contact_accepted, deleted_notification_id) =
2651        db.remove_contact(requester_id, responder_id).await?;
2652
2653    let pool = session.connection_pool().await;
2654    // Update outgoing contact requests of requester
2655    let mut update = proto::UpdateContacts::default();
2656    if contact_accepted {
2657        update.remove_contacts.push(responder_id.to_proto());
2658    } else {
2659        update
2660            .remove_outgoing_requests
2661            .push(responder_id.to_proto());
2662    }
2663    for connection_id in pool.user_connection_ids(requester_id) {
2664        session.peer.send(connection_id, update.clone())?;
2665    }
2666
2667    // Update incoming contact requests of responder
2668    let mut update = proto::UpdateContacts::default();
2669    if contact_accepted {
2670        update.remove_contacts.push(requester_id.to_proto());
2671    } else {
2672        update
2673            .remove_incoming_requests
2674            .push(requester_id.to_proto());
2675    }
2676    for connection_id in pool.user_connection_ids(responder_id) {
2677        session.peer.send(connection_id, update.clone())?;
2678        if let Some(notification_id) = deleted_notification_id {
2679            session.peer.send(
2680                connection_id,
2681                proto::DeleteNotification {
2682                    notification_id: notification_id.to_proto(),
2683                },
2684            )?;
2685        }
2686    }
2687
2688    response.send(proto::Ack {})?;
2689    Ok(())
2690}
2691
2692fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2693    version.0.minor() < 139
2694}
2695
2696async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2697    let plan = session.current_plan(&session.db().await).await?;
2698
2699    session
2700        .peer
2701        .send(
2702            session.connection_id,
2703            proto::UpdateUserPlan { plan: plan.into() },
2704        )
2705        .trace_err();
2706
2707    Ok(())
2708}
2709
2710async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2711    subscribe_user_to_channels(session.user_id(), &session).await?;
2712    Ok(())
2713}
2714
2715async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2716    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2717    let mut pool = session.connection_pool().await;
2718    for membership in &channels_for_user.channel_memberships {
2719        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2720    }
2721    session.peer.send(
2722        session.connection_id,
2723        build_update_user_channels(&channels_for_user),
2724    )?;
2725    session.peer.send(
2726        session.connection_id,
2727        build_channels_update(channels_for_user),
2728    )?;
2729    Ok(())
2730}
2731
2732/// Creates a new channel.
2733async fn create_channel(
2734    request: proto::CreateChannel,
2735    response: Response<proto::CreateChannel>,
2736    session: Session,
2737) -> Result<()> {
2738    let db = session.db().await;
2739
2740    let parent_id = request.parent_id.map(ChannelId::from_proto);
2741    let (channel, membership) = db
2742        .create_channel(&request.name, parent_id, session.user_id())
2743        .await?;
2744
2745    let root_id = channel.root_id();
2746    let channel = Channel::from_model(channel);
2747
2748    response.send(proto::CreateChannelResponse {
2749        channel: Some(channel.to_proto()),
2750        parent_id: request.parent_id,
2751    })?;
2752
2753    let mut connection_pool = session.connection_pool().await;
2754    if let Some(membership) = membership {
2755        connection_pool.subscribe_to_channel(
2756            membership.user_id,
2757            membership.channel_id,
2758            membership.role,
2759        );
2760        let update = proto::UpdateUserChannels {
2761            channel_memberships: vec![proto::ChannelMembership {
2762                channel_id: membership.channel_id.to_proto(),
2763                role: membership.role.into(),
2764            }],
2765            ..Default::default()
2766        };
2767        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2768            session.peer.send(connection_id, update.clone())?;
2769        }
2770    }
2771
2772    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2773        if !role.can_see_channel(channel.visibility) {
2774            continue;
2775        }
2776
2777        let update = proto::UpdateChannels {
2778            channels: vec![channel.to_proto()],
2779            ..Default::default()
2780        };
2781        session.peer.send(connection_id, update.clone())?;
2782    }
2783
2784    Ok(())
2785}
2786
2787/// Delete a channel
2788async fn delete_channel(
2789    request: proto::DeleteChannel,
2790    response: Response<proto::DeleteChannel>,
2791    session: Session,
2792) -> Result<()> {
2793    let db = session.db().await;
2794
2795    let channel_id = request.channel_id;
2796    let (root_channel, removed_channels) = db
2797        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2798        .await?;
2799    response.send(proto::Ack {})?;
2800
2801    // Notify members of removed channels
2802    let mut update = proto::UpdateChannels::default();
2803    update
2804        .delete_channels
2805        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2806
2807    let connection_pool = session.connection_pool().await;
2808    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2809        session.peer.send(connection_id, update.clone())?;
2810    }
2811
2812    Ok(())
2813}
2814
2815/// Invite someone to join a channel.
2816async fn invite_channel_member(
2817    request: proto::InviteChannelMember,
2818    response: Response<proto::InviteChannelMember>,
2819    session: Session,
2820) -> Result<()> {
2821    let db = session.db().await;
2822    let channel_id = ChannelId::from_proto(request.channel_id);
2823    let invitee_id = UserId::from_proto(request.user_id);
2824    let InviteMemberResult {
2825        channel,
2826        notifications,
2827    } = db
2828        .invite_channel_member(
2829            channel_id,
2830            invitee_id,
2831            session.user_id(),
2832            request.role().into(),
2833        )
2834        .await?;
2835
2836    let update = proto::UpdateChannels {
2837        channel_invitations: vec![channel.to_proto()],
2838        ..Default::default()
2839    };
2840
2841    let connection_pool = session.connection_pool().await;
2842    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2843        session.peer.send(connection_id, update.clone())?;
2844    }
2845
2846    send_notifications(&connection_pool, &session.peer, notifications);
2847
2848    response.send(proto::Ack {})?;
2849    Ok(())
2850}
2851
2852/// remove someone from a channel
2853async fn remove_channel_member(
2854    request: proto::RemoveChannelMember,
2855    response: Response<proto::RemoveChannelMember>,
2856    session: Session,
2857) -> Result<()> {
2858    let db = session.db().await;
2859    let channel_id = ChannelId::from_proto(request.channel_id);
2860    let member_id = UserId::from_proto(request.user_id);
2861
2862    let RemoveChannelMemberResult {
2863        membership_update,
2864        notification_id,
2865    } = db
2866        .remove_channel_member(channel_id, member_id, session.user_id())
2867        .await?;
2868
2869    let mut connection_pool = session.connection_pool().await;
2870    notify_membership_updated(
2871        &mut connection_pool,
2872        membership_update,
2873        member_id,
2874        &session.peer,
2875    );
2876    for connection_id in connection_pool.user_connection_ids(member_id) {
2877        if let Some(notification_id) = notification_id {
2878            session
2879                .peer
2880                .send(
2881                    connection_id,
2882                    proto::DeleteNotification {
2883                        notification_id: notification_id.to_proto(),
2884                    },
2885                )
2886                .trace_err();
2887        }
2888    }
2889
2890    response.send(proto::Ack {})?;
2891    Ok(())
2892}
2893
2894/// Toggle the channel between public and private.
2895/// Care is taken to maintain the invariant that public channels only descend from public channels,
2896/// (though members-only channels can appear at any point in the hierarchy).
2897async fn set_channel_visibility(
2898    request: proto::SetChannelVisibility,
2899    response: Response<proto::SetChannelVisibility>,
2900    session: Session,
2901) -> Result<()> {
2902    let db = session.db().await;
2903    let channel_id = ChannelId::from_proto(request.channel_id);
2904    let visibility = request.visibility().into();
2905
2906    let channel_model = db
2907        .set_channel_visibility(channel_id, visibility, session.user_id())
2908        .await?;
2909    let root_id = channel_model.root_id();
2910    let channel = Channel::from_model(channel_model);
2911
2912    let mut connection_pool = session.connection_pool().await;
2913    for (user_id, role) in connection_pool
2914        .channel_user_ids(root_id)
2915        .collect::<Vec<_>>()
2916        .into_iter()
2917    {
2918        let update = if role.can_see_channel(channel.visibility) {
2919            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2920            proto::UpdateChannels {
2921                channels: vec![channel.to_proto()],
2922                ..Default::default()
2923            }
2924        } else {
2925            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2926            proto::UpdateChannels {
2927                delete_channels: vec![channel.id.to_proto()],
2928                ..Default::default()
2929            }
2930        };
2931
2932        for connection_id in connection_pool.user_connection_ids(user_id) {
2933            session.peer.send(connection_id, update.clone())?;
2934        }
2935    }
2936
2937    response.send(proto::Ack {})?;
2938    Ok(())
2939}
2940
2941/// Alter the role for a user in the channel.
2942async fn set_channel_member_role(
2943    request: proto::SetChannelMemberRole,
2944    response: Response<proto::SetChannelMemberRole>,
2945    session: Session,
2946) -> Result<()> {
2947    let db = session.db().await;
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let member_id = UserId::from_proto(request.user_id);
2950    let result = db
2951        .set_channel_member_role(
2952            channel_id,
2953            session.user_id(),
2954            member_id,
2955            request.role().into(),
2956        )
2957        .await?;
2958
2959    match result {
2960        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2961            let mut connection_pool = session.connection_pool().await;
2962            notify_membership_updated(
2963                &mut connection_pool,
2964                membership_update,
2965                member_id,
2966                &session.peer,
2967            )
2968        }
2969        db::SetMemberRoleResult::InviteUpdated(channel) => {
2970            let update = proto::UpdateChannels {
2971                channel_invitations: vec![channel.to_proto()],
2972                ..Default::default()
2973            };
2974
2975            for connection_id in session
2976                .connection_pool()
2977                .await
2978                .user_connection_ids(member_id)
2979            {
2980                session.peer.send(connection_id, update.clone())?;
2981            }
2982        }
2983    }
2984
2985    response.send(proto::Ack {})?;
2986    Ok(())
2987}
2988
2989/// Change the name of a channel
2990async fn rename_channel(
2991    request: proto::RenameChannel,
2992    response: Response<proto::RenameChannel>,
2993    session: Session,
2994) -> Result<()> {
2995    let db = session.db().await;
2996    let channel_id = ChannelId::from_proto(request.channel_id);
2997    let channel_model = db
2998        .rename_channel(channel_id, session.user_id(), &request.name)
2999        .await?;
3000    let root_id = channel_model.root_id();
3001    let channel = Channel::from_model(channel_model);
3002
3003    response.send(proto::RenameChannelResponse {
3004        channel: Some(channel.to_proto()),
3005    })?;
3006
3007    let connection_pool = session.connection_pool().await;
3008    let update = proto::UpdateChannels {
3009        channels: vec![channel.to_proto()],
3010        ..Default::default()
3011    };
3012    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3013        if role.can_see_channel(channel.visibility) {
3014            session.peer.send(connection_id, update.clone())?;
3015        }
3016    }
3017
3018    Ok(())
3019}
3020
3021/// Move a channel to a new parent.
3022async fn move_channel(
3023    request: proto::MoveChannel,
3024    response: Response<proto::MoveChannel>,
3025    session: Session,
3026) -> Result<()> {
3027    let channel_id = ChannelId::from_proto(request.channel_id);
3028    let to = ChannelId::from_proto(request.to);
3029
3030    let (root_id, channels) = session
3031        .db()
3032        .await
3033        .move_channel(channel_id, to, session.user_id())
3034        .await?;
3035
3036    let connection_pool = session.connection_pool().await;
3037    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3038        let channels = channels
3039            .iter()
3040            .filter_map(|channel| {
3041                if role.can_see_channel(channel.visibility) {
3042                    Some(channel.to_proto())
3043                } else {
3044                    None
3045                }
3046            })
3047            .collect::<Vec<_>>();
3048        if channels.is_empty() {
3049            continue;
3050        }
3051
3052        let update = proto::UpdateChannels {
3053            channels,
3054            ..Default::default()
3055        };
3056
3057        session.peer.send(connection_id, update.clone())?;
3058    }
3059
3060    response.send(Ack {})?;
3061    Ok(())
3062}
3063
3064/// Get the list of channel members
3065async fn get_channel_members(
3066    request: proto::GetChannelMembers,
3067    response: Response<proto::GetChannelMembers>,
3068    session: Session,
3069) -> Result<()> {
3070    let db = session.db().await;
3071    let channel_id = ChannelId::from_proto(request.channel_id);
3072    let limit = if request.limit == 0 {
3073        u16::MAX as u64
3074    } else {
3075        request.limit
3076    };
3077    let (members, users) = db
3078        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3079        .await?;
3080    response.send(proto::GetChannelMembersResponse { members, users })?;
3081    Ok(())
3082}
3083
3084/// Accept or decline a channel invitation.
3085async fn respond_to_channel_invite(
3086    request: proto::RespondToChannelInvite,
3087    response: Response<proto::RespondToChannelInvite>,
3088    session: Session,
3089) -> Result<()> {
3090    let db = session.db().await;
3091    let channel_id = ChannelId::from_proto(request.channel_id);
3092    let RespondToChannelInvite {
3093        membership_update,
3094        notifications,
3095    } = db
3096        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3097        .await?;
3098
3099    let mut connection_pool = session.connection_pool().await;
3100    if let Some(membership_update) = membership_update {
3101        notify_membership_updated(
3102            &mut connection_pool,
3103            membership_update,
3104            session.user_id(),
3105            &session.peer,
3106        );
3107    } else {
3108        let update = proto::UpdateChannels {
3109            remove_channel_invitations: vec![channel_id.to_proto()],
3110            ..Default::default()
3111        };
3112
3113        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3114            session.peer.send(connection_id, update.clone())?;
3115        }
3116    };
3117
3118    send_notifications(&connection_pool, &session.peer, notifications);
3119
3120    response.send(proto::Ack {})?;
3121
3122    Ok(())
3123}
3124
3125/// Join the channels' room
3126async fn join_channel(
3127    request: proto::JoinChannel,
3128    response: Response<proto::JoinChannel>,
3129    session: Session,
3130) -> Result<()> {
3131    let channel_id = ChannelId::from_proto(request.channel_id);
3132    join_channel_internal(channel_id, Box::new(response), session).await
3133}
3134
/// A reply handle that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// request kind.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-relative path: resolves to the inherent `Response::send`
        // (inherent items take precedence over trait items), making explicit
        // that this delegates rather than recursing into the trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation to the inherent `Response::send` as above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3148
/// Join the call room belonging to `channel_id` on behalf of the session's user,
/// answering through `response`.
///
/// The steps, in order:
/// 1. Leave any stale room connection the database reports for this user.
/// 2. Join the channel's room.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Reply.
/// 5. Propagate any membership change.
/// 6. Broadcast the room update.
/// 7. Broadcast the channel update.
/// 8. Refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // `db` is declared inside this block, so the database guard is released
    // before `channel_updated` and `update_user_contacts` run.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before leaving the stale room and re-acquire it
            // afterwards; presumably `leave_room_for_session` locks the database
            // itself (TODO confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and cannot publish. Every other role gets
        // a room token with publishing allowed. If token creation fails,
        // `trace_err()?` logs the error and the reply carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Only propagate membership when `join_channel` reported a change.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to return its channel; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3244
3245/// Start editing the channel notes
3246async fn join_channel_buffer(
3247    request: proto::JoinChannelBuffer,
3248    response: Response<proto::JoinChannelBuffer>,
3249    session: Session,
3250) -> Result<()> {
3251    let db = session.db().await;
3252    let channel_id = ChannelId::from_proto(request.channel_id);
3253
3254    let open_response = db
3255        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3256        .await?;
3257
3258    let collaborators = open_response.collaborators.clone();
3259    response.send(open_response)?;
3260
3261    let update = UpdateChannelBufferCollaborators {
3262        channel_id: channel_id.to_proto(),
3263        collaborators: collaborators.clone(),
3264    };
3265    channel_buffer_updated(
3266        session.connection_id,
3267        collaborators
3268            .iter()
3269            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3270        &update,
3271        &session.peer,
3272    );
3273
3274    Ok(())
3275}
3276
3277/// Edit the channel notes
3278async fn update_channel_buffer(
3279    request: proto::UpdateChannelBuffer,
3280    session: Session,
3281) -> Result<()> {
3282    let db = session.db().await;
3283    let channel_id = ChannelId::from_proto(request.channel_id);
3284
3285    let (collaborators, epoch, version) = db
3286        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3287        .await?;
3288
3289    channel_buffer_updated(
3290        session.connection_id,
3291        collaborators.clone(),
3292        &proto::UpdateChannelBuffer {
3293            channel_id: channel_id.to_proto(),
3294            operations: request.operations,
3295        },
3296        &session.peer,
3297    );
3298
3299    let pool = &*session.connection_pool().await;
3300
3301    let non_collaborators =
3302        pool.channel_connection_ids(channel_id)
3303            .filter_map(|(connection_id, _)| {
3304                if collaborators.contains(&connection_id) {
3305                    None
3306                } else {
3307                    Some(connection_id)
3308                }
3309            });
3310
3311    broadcast(None, non_collaborators, |peer_id| {
3312        session.peer.send(
3313            peer_id,
3314            proto::UpdateChannels {
3315                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3316                    channel_id: channel_id.to_proto(),
3317                    epoch: epoch as u64,
3318                    version: version.clone(),
3319                }],
3320                ..Default::default()
3321            },
3322        )
3323    });
3324
3325    Ok(())
3326}
3327
3328/// Rejoin the channel notes after a connection blip
3329async fn rejoin_channel_buffers(
3330    request: proto::RejoinChannelBuffers,
3331    response: Response<proto::RejoinChannelBuffers>,
3332    session: Session,
3333) -> Result<()> {
3334    let db = session.db().await;
3335    let buffers = db
3336        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3337        .await?;
3338
3339    for rejoined_buffer in &buffers {
3340        let collaborators_to_notify = rejoined_buffer
3341            .buffer
3342            .collaborators
3343            .iter()
3344            .filter_map(|c| Some(c.peer_id?.into()));
3345        channel_buffer_updated(
3346            session.connection_id,
3347            collaborators_to_notify,
3348            &proto::UpdateChannelBufferCollaborators {
3349                channel_id: rejoined_buffer.buffer.channel_id,
3350                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3351            },
3352            &session.peer,
3353        );
3354    }
3355
3356    response.send(proto::RejoinChannelBuffersResponse {
3357        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3358    })?;
3359
3360    Ok(())
3361}
3362
3363/// Stop editing the channel notes
3364async fn leave_channel_buffer(
3365    request: proto::LeaveChannelBuffer,
3366    response: Response<proto::LeaveChannelBuffer>,
3367    session: Session,
3368) -> Result<()> {
3369    let db = session.db().await;
3370    let channel_id = ChannelId::from_proto(request.channel_id);
3371
3372    let left_buffer = db
3373        .leave_channel_buffer(channel_id, session.connection_id)
3374        .await?;
3375
3376    response.send(Ack {})?;
3377
3378    channel_buffer_updated(
3379        session.connection_id,
3380        left_buffer.connections,
3381        &proto::UpdateChannelBufferCollaborators {
3382            channel_id: channel_id.to_proto(),
3383            collaborators: left_buffer.collaborators,
3384        },
3385        &session.peer,
3386    );
3387
3388    Ok(())
3389}
3390
3391fn channel_buffer_updated<T: EnvelopedMessage>(
3392    sender_id: ConnectionId,
3393    collaborators: impl IntoIterator<Item = ConnectionId>,
3394    message: &T,
3395    peer: &Peer,
3396) {
3397    broadcast(Some(sender_id), collaborators, |peer_id| {
3398        peer.send(peer_id, message.clone())
3399    });
3400}
3401
3402fn send_notifications(
3403    connection_pool: &ConnectionPool,
3404    peer: &Peer,
3405    notifications: db::NotificationBatch,
3406) {
3407    for (user_id, notification) in notifications {
3408        for connection_id in connection_pool.user_connection_ids(user_id) {
3409            if let Err(error) = peer.send(
3410                connection_id,
3411                proto::AddNotification {
3412                    notification: Some(notification.clone()),
3413                },
3414            ) {
3415                tracing::error!(
3416                    "failed to send notification to {:?} {}",
3417                    connection_id,
3418                    error
3419                );
3420            }
3421        }
3422    }
3423}
3424
3425/// Send a message to the channel
3426async fn send_channel_message(
3427    request: proto::SendChannelMessage,
3428    response: Response<proto::SendChannelMessage>,
3429    session: Session,
3430) -> Result<()> {
3431    // Validate the message body.
3432    let body = request.body.trim().to_string();
3433    if body.len() > MAX_MESSAGE_LEN {
3434        return Err(anyhow!("message is too long"))?;
3435    }
3436    if body.is_empty() {
3437        return Err(anyhow!("message can't be blank"))?;
3438    }
3439
3440    // TODO: adjust mentions if body is trimmed
3441
3442    let timestamp = OffsetDateTime::now_utc();
3443    let nonce = request
3444        .nonce
3445        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3446
3447    let channel_id = ChannelId::from_proto(request.channel_id);
3448    let CreatedChannelMessage {
3449        message_id,
3450        participant_connection_ids,
3451        notifications,
3452    } = session
3453        .db()
3454        .await
3455        .create_channel_message(
3456            channel_id,
3457            session.user_id(),
3458            &body,
3459            &request.mentions,
3460            timestamp,
3461            nonce.clone().into(),
3462            request.reply_to_message_id.map(MessageId::from_proto),
3463        )
3464        .await?;
3465
3466    let message = proto::ChannelMessage {
3467        sender_id: session.user_id().to_proto(),
3468        id: message_id.to_proto(),
3469        body,
3470        mentions: request.mentions,
3471        timestamp: timestamp.unix_timestamp() as u64,
3472        nonce: Some(nonce),
3473        reply_to_message_id: request.reply_to_message_id,
3474        edited_at: None,
3475    };
3476    broadcast(
3477        Some(session.connection_id),
3478        participant_connection_ids.clone(),
3479        |connection| {
3480            session.peer.send(
3481                connection,
3482                proto::ChannelMessageSent {
3483                    channel_id: channel_id.to_proto(),
3484                    message: Some(message.clone()),
3485                },
3486            )
3487        },
3488    );
3489    response.send(proto::SendChannelMessageResponse {
3490        message: Some(message),
3491    })?;
3492
3493    let pool = &*session.connection_pool().await;
3494    let non_participants =
3495        pool.channel_connection_ids(channel_id)
3496            .filter_map(|(connection_id, _)| {
3497                if participant_connection_ids.contains(&connection_id) {
3498                    None
3499                } else {
3500                    Some(connection_id)
3501                }
3502            });
3503    broadcast(None, non_participants, |peer_id| {
3504        session.peer.send(
3505            peer_id,
3506            proto::UpdateChannels {
3507                latest_channel_message_ids: vec![proto::ChannelMessageId {
3508                    channel_id: channel_id.to_proto(),
3509                    message_id: message_id.to_proto(),
3510                }],
3511                ..Default::default()
3512            },
3513        )
3514    });
3515    send_notifications(pool, &session.peer, notifications);
3516
3517    Ok(())
3518}
3519
3520/// Delete a channel message
3521async fn remove_channel_message(
3522    request: proto::RemoveChannelMessage,
3523    response: Response<proto::RemoveChannelMessage>,
3524    session: Session,
3525) -> Result<()> {
3526    let channel_id = ChannelId::from_proto(request.channel_id);
3527    let message_id = MessageId::from_proto(request.message_id);
3528    let (connection_ids, existing_notification_ids) = session
3529        .db()
3530        .await
3531        .remove_channel_message(channel_id, message_id, session.user_id())
3532        .await?;
3533
3534    broadcast(
3535        Some(session.connection_id),
3536        connection_ids,
3537        move |connection| {
3538            session.peer.send(connection, request.clone())?;
3539
3540            for notification_id in &existing_notification_ids {
3541                session.peer.send(
3542                    connection,
3543                    proto::DeleteNotification {
3544                        notification_id: (*notification_id).to_proto(),
3545                    },
3546                )?;
3547            }
3548
3549            Ok(())
3550        },
3551    );
3552    response.send(proto::Ack {})?;
3553    Ok(())
3554}
3555
3556async fn update_channel_message(
3557    request: proto::UpdateChannelMessage,
3558    response: Response<proto::UpdateChannelMessage>,
3559    session: Session,
3560) -> Result<()> {
3561    let channel_id = ChannelId::from_proto(request.channel_id);
3562    let message_id = MessageId::from_proto(request.message_id);
3563    let updated_at = OffsetDateTime::now_utc();
3564    let UpdatedChannelMessage {
3565        message_id,
3566        participant_connection_ids,
3567        notifications,
3568        reply_to_message_id,
3569        timestamp,
3570        deleted_mention_notification_ids,
3571        updated_mention_notifications,
3572    } = session
3573        .db()
3574        .await
3575        .update_channel_message(
3576            channel_id,
3577            message_id,
3578            session.user_id(),
3579            request.body.as_str(),
3580            &request.mentions,
3581            updated_at,
3582        )
3583        .await?;
3584
3585    let nonce = request
3586        .nonce
3587        .clone()
3588        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3589
3590    let message = proto::ChannelMessage {
3591        sender_id: session.user_id().to_proto(),
3592        id: message_id.to_proto(),
3593        body: request.body.clone(),
3594        mentions: request.mentions.clone(),
3595        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3596        nonce: Some(nonce),
3597        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3598        edited_at: Some(updated_at.unix_timestamp() as u64),
3599    };
3600
3601    response.send(proto::Ack {})?;
3602
3603    let pool = &*session.connection_pool().await;
3604    broadcast(
3605        Some(session.connection_id),
3606        participant_connection_ids,
3607        |connection| {
3608            session.peer.send(
3609                connection,
3610                proto::ChannelMessageUpdate {
3611                    channel_id: channel_id.to_proto(),
3612                    message: Some(message.clone()),
3613                },
3614            )?;
3615
3616            for notification_id in &deleted_mention_notification_ids {
3617                session.peer.send(
3618                    connection,
3619                    proto::DeleteNotification {
3620                        notification_id: (*notification_id).to_proto(),
3621                    },
3622                )?;
3623            }
3624
3625            for notification in &updated_mention_notifications {
3626                session.peer.send(
3627                    connection,
3628                    proto::UpdateNotification {
3629                        notification: Some(notification.clone()),
3630                    },
3631                )?;
3632            }
3633
3634            Ok(())
3635        },
3636    );
3637
3638    send_notifications(pool, &session.peer, notifications);
3639
3640    Ok(())
3641}
3642
3643/// Mark a channel message as read
3644async fn acknowledge_channel_message(
3645    request: proto::AckChannelMessage,
3646    session: Session,
3647) -> Result<()> {
3648    let channel_id = ChannelId::from_proto(request.channel_id);
3649    let message_id = MessageId::from_proto(request.message_id);
3650    let notifications = session
3651        .db()
3652        .await
3653        .observe_channel_message(channel_id, session.user_id(), message_id)
3654        .await?;
3655    send_notifications(
3656        &*session.connection_pool().await,
3657        &session.peer,
3658        notifications,
3659    );
3660    Ok(())
3661}
3662
3663/// Mark a buffer version as synced
3664async fn acknowledge_buffer_version(
3665    request: proto::AckBufferOperation,
3666    session: Session,
3667) -> Result<()> {
3668    let buffer_id = BufferId::from_proto(request.buffer_id);
3669    session
3670        .db()
3671        .await
3672        .observe_buffer_version(
3673            buffer_id,
3674            session.user_id(),
3675            request.epoch as i32,
3676            &request.version,
3677        )
3678        .await?;
3679    Ok(())
3680}
3681
3682/// Get a Supermaven API key for the user
3683async fn get_supermaven_api_key(
3684    _request: proto::GetSupermavenApiKey,
3685    response: Response<proto::GetSupermavenApiKey>,
3686    session: Session,
3687) -> Result<()> {
3688    let user_id: String = session.user_id().to_string();
3689    if !session.is_staff() {
3690        return Err(anyhow!("supermaven not enabled for this account"))?;
3691    }
3692
3693    let email = session
3694        .email()
3695        .ok_or_else(|| anyhow!("user must have an email"))?;
3696
3697    let supermaven_admin_api = session
3698        .supermaven_client
3699        .as_ref()
3700        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3701
3702    let result = supermaven_admin_api
3703        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3704        .await?;
3705
3706    response.send(proto::GetSupermavenApiKeyResponse {
3707        api_key: result.api_key,
3708    })?;
3709
3710    Ok(())
3711}
3712
3713/// Start receiving chat updates for a channel
3714async fn join_channel_chat(
3715    request: proto::JoinChannelChat,
3716    response: Response<proto::JoinChannelChat>,
3717    session: Session,
3718) -> Result<()> {
3719    let channel_id = ChannelId::from_proto(request.channel_id);
3720
3721    let db = session.db().await;
3722    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3723        .await?;
3724    let messages = db
3725        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3726        .await?;
3727    response.send(proto::JoinChannelChatResponse {
3728        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3729        messages,
3730    })?;
3731    Ok(())
3732}
3733
3734/// Stop receiving chat updates for a channel
3735async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3736    let channel_id = ChannelId::from_proto(request.channel_id);
3737    session
3738        .db()
3739        .await
3740        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3741        .await?;
3742    Ok(())
3743}
3744
3745/// Retrieve the chat history for a channel
3746async fn get_channel_messages(
3747    request: proto::GetChannelMessages,
3748    response: Response<proto::GetChannelMessages>,
3749    session: Session,
3750) -> Result<()> {
3751    let channel_id = ChannelId::from_proto(request.channel_id);
3752    let messages = session
3753        .db()
3754        .await
3755        .get_channel_messages(
3756            channel_id,
3757            session.user_id(),
3758            MESSAGE_COUNT_PER_PAGE,
3759            Some(MessageId::from_proto(request.before_message_id)),
3760        )
3761        .await?;
3762    response.send(proto::GetChannelMessagesResponse {
3763        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3764        messages,
3765    })?;
3766    Ok(())
3767}
3768
3769/// Retrieve specific chat messages
3770async fn get_channel_messages_by_id(
3771    request: proto::GetChannelMessagesById,
3772    response: Response<proto::GetChannelMessagesById>,
3773    session: Session,
3774) -> Result<()> {
3775    let message_ids = request
3776        .message_ids
3777        .iter()
3778        .map(|id| MessageId::from_proto(*id))
3779        .collect::<Vec<_>>();
3780    let messages = session
3781        .db()
3782        .await
3783        .get_channel_messages_by_id(session.user_id(), &message_ids)
3784        .await?;
3785    response.send(proto::GetChannelMessagesResponse {
3786        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3787        messages,
3788    })?;
3789    Ok(())
3790}
3791
3792/// Retrieve the current users notifications
3793async fn get_notifications(
3794    request: proto::GetNotifications,
3795    response: Response<proto::GetNotifications>,
3796    session: Session,
3797) -> Result<()> {
3798    let notifications = session
3799        .db()
3800        .await
3801        .get_notifications(
3802            session.user_id(),
3803            NOTIFICATION_COUNT_PER_PAGE,
3804            request.before_id.map(db::NotificationId::from_proto),
3805        )
3806        .await?;
3807    response.send(proto::GetNotificationsResponse {
3808        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3809        notifications,
3810    })?;
3811    Ok(())
3812}
3813
3814/// Mark notifications as read
3815async fn mark_notification_as_read(
3816    request: proto::MarkNotificationRead,
3817    response: Response<proto::MarkNotificationRead>,
3818    session: Session,
3819) -> Result<()> {
3820    let database = &session.db().await;
3821    let notifications = database
3822        .mark_notification_as_read_by_id(
3823            session.user_id(),
3824            NotificationId::from_proto(request.notification_id),
3825        )
3826        .await?;
3827    send_notifications(
3828        &*session.connection_pool().await,
3829        &session.peer,
3830        notifications,
3831    );
3832    response.send(proto::Ack {})?;
3833    Ok(())
3834}
3835
3836/// Get the current users information
3837async fn get_private_user_info(
3838    _request: proto::GetPrivateUserInfo,
3839    response: Response<proto::GetPrivateUserInfo>,
3840    session: Session,
3841) -> Result<()> {
3842    let db = session.db().await;
3843
3844    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3845    let user = db
3846        .get_user_by_id(session.user_id())
3847        .await?
3848        .ok_or_else(|| anyhow!("user not found"))?;
3849    let flags = db.get_user_flags(session.user_id()).await?;
3850
3851    response.send(proto::GetPrivateUserInfoResponse {
3852        metrics_id,
3853        staff: user.admin,
3854        flags,
3855        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3856    })?;
3857    Ok(())
3858}
3859
3860/// Accept the terms of service (tos) on behalf of the current user
3861async fn accept_terms_of_service(
3862    _request: proto::AcceptTermsOfService,
3863    response: Response<proto::AcceptTermsOfService>,
3864    session: Session,
3865) -> Result<()> {
3866    let db = session.db().await;
3867
3868    let accepted_tos_at = Utc::now();
3869    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3870        .await?;
3871
3872    response.send(proto::AcceptTermsOfServiceResponse {
3873        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3874    })?;
3875    Ok(())
3876}
3877
/// The minimum account age an account must have in order to use the LLM service.
///
/// Set to 30 days. NOTE(review): nothing in this part of the file reads this
/// constant, and `get_llm_api_token` does not check account age, so enforcement
/// presumably happens elsewhere. Confirm.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3880
3881async fn get_llm_api_token(
3882    _request: proto::GetLlmToken,
3883    response: Response<proto::GetLlmToken>,
3884    session: Session,
3885) -> Result<()> {
3886    let db = session.db().await;
3887
3888    let flags = db.get_user_flags(session.user_id()).await?;
3889    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3890
3891    if !session.is_staff() && !has_language_models_feature_flag {
3892        Err(anyhow!("permission denied"))?
3893    }
3894
3895    let user_id = session.user_id();
3896    let user = db
3897        .get_user_by_id(user_id)
3898        .await?
3899        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3900
3901    if user.accepted_tos_at.is_none() {
3902        Err(anyhow!("terms of service not accepted"))?
3903    }
3904
3905    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
3906    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
3907    let billing_preferences = db.get_billing_preferences(user.id).await?;
3908
3909    let token = LlmTokenClaims::create(
3910        &user,
3911        session.is_staff(),
3912        billing_preferences,
3913        &flags,
3914        has_legacy_llm_subscription,
3915        billing_subscription,
3916        session.system_id.clone(),
3917        &session.app_state.config,
3918    )?;
3919    response.send(proto::GetLlmTokenResponse { token })?;
3920    Ok(())
3921}
3922
3923fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3924    let message = match message {
3925        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3926        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3927        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3928        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3929        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3930            code: frame.code.into(),
3931            reason: frame.reason.as_str().to_owned().into(),
3932        })),
3933        // We should never receive a frame while reading the message, according
3934        // to the `tungstenite` maintainers:
3935        //
3936        // > It cannot occur when you read messages from the WebSocket, but it
3937        // > can be used when you want to send the raw frames (e.g. you want to
3938        // > send the frames to the WebSocket without composing the full message first).
3939        // >
3940        // > — https://github.com/snapview/tungstenite-rs/issues/268
3941        TungsteniteMessage::Frame(_) => {
3942            bail!("received an unexpected frame while reading the message")
3943        }
3944    };
3945
3946    Ok(message)
3947}
3948
3949fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3950    match message {
3951        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3952        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3953        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3954        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3955        AxumMessage::Close(frame) => {
3956            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3957                code: frame.code.into(),
3958                reason: frame.reason.as_ref().into(),
3959            }))
3960        }
3961    }
3962}
3963
3964fn notify_membership_updated(
3965    connection_pool: &mut ConnectionPool,
3966    result: MembershipUpdated,
3967    user_id: UserId,
3968    peer: &Peer,
3969) {
3970    for membership in &result.new_channels.channel_memberships {
3971        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3972    }
3973    for channel_id in &result.removed_channels {
3974        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3975    }
3976
3977    let user_channels_update = proto::UpdateUserChannels {
3978        channel_memberships: result
3979            .new_channels
3980            .channel_memberships
3981            .iter()
3982            .map(|cm| proto::ChannelMembership {
3983                channel_id: cm.channel_id.to_proto(),
3984                role: cm.role.into(),
3985            })
3986            .collect(),
3987        ..Default::default()
3988    };
3989
3990    let mut update = build_channels_update(result.new_channels);
3991    update.delete_channels = result
3992        .removed_channels
3993        .into_iter()
3994        .map(|id| id.to_proto())
3995        .collect();
3996    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3997
3998    for connection_id in connection_pool.user_connection_ids(user_id) {
3999        peer.send(connection_id, user_channels_update.clone())
4000            .trace_err();
4001        peer.send(connection_id, update.clone()).trace_err();
4002    }
4003}
4004
4005fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4006    proto::UpdateUserChannels {
4007        channel_memberships: channels
4008            .channel_memberships
4009            .iter()
4010            .map(|m| proto::ChannelMembership {
4011                channel_id: m.channel_id.to_proto(),
4012                role: m.role.into(),
4013            })
4014            .collect(),
4015        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4016        observed_channel_message_id: channels.observed_channel_messages.clone(),
4017    }
4018}
4019
4020fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4021    let mut update = proto::UpdateChannels::default();
4022
4023    for channel in channels.channels {
4024        update.channels.push(channel.to_proto());
4025    }
4026
4027    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4028    update.latest_channel_message_ids = channels.latest_channel_messages;
4029
4030    for (channel_id, participants) in channels.channel_participants {
4031        update
4032            .channel_participants
4033            .push(proto::ChannelParticipants {
4034                channel_id: channel_id.to_proto(),
4035                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4036            });
4037    }
4038
4039    for channel in channels.invited_channels {
4040        update.channel_invitations.push(channel.to_proto());
4041    }
4042
4043    update
4044}
4045
4046fn build_initial_contacts_update(
4047    contacts: Vec<db::Contact>,
4048    pool: &ConnectionPool,
4049) -> proto::UpdateContacts {
4050    let mut update = proto::UpdateContacts::default();
4051
4052    for contact in contacts {
4053        match contact {
4054            db::Contact::Accepted { user_id, busy } => {
4055                update.contacts.push(contact_for_user(user_id, busy, pool));
4056            }
4057            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4058            db::Contact::Incoming { user_id } => {
4059                update
4060                    .incoming_requests
4061                    .push(proto::IncomingContactRequest {
4062                        requester_id: user_id.to_proto(),
4063                    })
4064            }
4065        }
4066    }
4067
4068    update
4069}
4070
4071fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4072    proto::Contact {
4073        user_id: user_id.to_proto(),
4074        online: pool.is_user_online(user_id),
4075        busy,
4076    }
4077}
4078
4079fn room_updated(room: &proto::Room, peer: &Peer) {
4080    broadcast(
4081        None,
4082        room.participants
4083            .iter()
4084            .filter_map(|participant| Some(participant.peer_id?.into())),
4085        |peer_id| {
4086            peer.send(
4087                peer_id,
4088                proto::RoomUpdated {
4089                    room: Some(room.clone()),
4090                },
4091            )
4092        },
4093    );
4094}
4095
4096fn channel_updated(
4097    channel: &db::channel::Model,
4098    room: &proto::Room,
4099    peer: &Peer,
4100    pool: &ConnectionPool,
4101) {
4102    let participants = room
4103        .participants
4104        .iter()
4105        .map(|p| p.user_id)
4106        .collect::<Vec<_>>();
4107
4108    broadcast(
4109        None,
4110        pool.channel_connection_ids(channel.root_id())
4111            .filter_map(|(channel_id, role)| {
4112                role.can_see_channel(channel.visibility)
4113                    .then_some(channel_id)
4114            }),
4115        |peer_id| {
4116            peer.send(
4117                peer_id,
4118                proto::UpdateChannels {
4119                    channel_participants: vec![proto::ChannelParticipants {
4120                        channel_id: channel.id.to_proto(),
4121                        participant_user_ids: participants.clone(),
4122                    }],
4123                    ..Default::default()
4124                },
4125            )
4126        },
4127    );
4128}
4129
4130async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4131    let db = session.db().await;
4132
4133    let contacts = db.get_contacts(user_id).await?;
4134    let busy = db.is_user_busy(user_id).await?;
4135
4136    let pool = session.connection_pool().await;
4137    let updated_contact = contact_for_user(user_id, busy, &pool);
4138    for contact in contacts {
4139        if let db::Contact::Accepted {
4140            user_id: contact_user_id,
4141            ..
4142        } = contact
4143        {
4144            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4145                session
4146                    .peer
4147                    .send(
4148                        contact_conn_id,
4149                        proto::UpdateContacts {
4150                            contacts: vec![updated_contact.clone()],
4151                            remove_contacts: Default::default(),
4152                            incoming_requests: Default::default(),
4153                            remove_incoming_requests: Default::default(),
4154                            outgoing_requests: Default::default(),
4155                            remove_outgoing_requests: Default::default(),
4156                        },
4157                    )
4158                    .trace_err();
4159            }
4160        }
4161    }
4162    Ok(())
4163}
4164
/// Remove `connection_id` from the room it is in, then propagate the
/// consequences.
///
/// Returns `Ok(())` immediately if `leave_room` reports no room for the
/// connection. Otherwise, in order:
/// 1. calls `project_left` for every project in `left_room.left_projects`,
///    then sends the updated room to the room's participants;
/// 2. if the room belongs to a channel, broadcasts that channel's participant
///    list;
/// 3. sends `CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`;
/// 4. refreshes contact status for this session's user and for those users;
/// 5. if a LiveKit client is configured, removes this user from the LiveKit
///    room, and also deletes that room when `left_room.deleted` was set.
///
/// LiveKit and send failures are logged via `trace_err` and ignored.
/// NOTE(review): an error from `update_user_contacts` returns early, which
/// skips the LiveKit cleanup in step 5. Confirm that this is intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so they stay
    // available after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` broadcast by
        // `room_updated` carries a default (empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool guard is dropped before `update_user_contacts` runs;
    // that function calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4238
4239async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4240    let left_channel_buffers = session
4241        .db()
4242        .await
4243        .leave_channel_buffers(session.connection_id)
4244        .await?;
4245
4246    for left_buffer in left_channel_buffers {
4247        channel_buffer_updated(
4248            session.connection_id,
4249            left_buffer.connections,
4250            &proto::UpdateChannelBufferCollaborators {
4251                channel_id: left_buffer.channel_id.to_proto(),
4252                collaborators: left_buffer.collaborators,
4253            },
4254            &session.peer,
4255        );
4256    }
4257
4258    Ok(())
4259}
4260
4261fn project_left(project: &db::LeftProject, session: &Session) {
4262    for connection_id in &project.connection_ids {
4263        if project.should_unshare {
4264            session
4265                .peer
4266                .send(
4267                    *connection_id,
4268                    proto::UnshareProject {
4269                        project_id: project.id.to_proto(),
4270                    },
4271                )
4272                .trace_err();
4273        } else {
4274            session
4275                .peer
4276                .send(
4277                    *connection_id,
4278                    proto::RemoveProjectCollaborator {
4279                        project_id: project.id.to_proto(),
4280                        peer_id: Some(session.connection_id.into()),
4281                    },
4282                )
4283                .trace_err();
4284        }
4285    }
4286}
4287
4288pub trait ResultExt {
4289    type Ok;
4290
4291    fn trace_err(self) -> Option<Self::Ok>;
4292}
4293
4294impl<T, E> ResultExt for Result<T, E>
4295where
4296    E: std::fmt::Debug,
4297{
4298    type Ok = T;
4299
4300    #[track_caller]
4301    fn trace_err(self) -> Option<T> {
4302        match self {
4303            Ok(value) => Some(value),
4304            Err(error) => {
4305                tracing::error!("{:?}", error);
4306                None
4307            }
4308        }
4309    }
4310}