rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Error, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use reqwest_client::ReqwestClient;
  37use rpc::proto::split_repository_update;
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use time::OffsetDateTime;
  69use tokio::sync::{MutexGuard, Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
  83const MAX_MESSAGE_LEN: usize = 1024;
  84const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
 103#[derive(Clone, Debug)]
 104pub enum Principal {
 105    User(User),
 106    Impersonated { user: User, admin: User },
 107}
 108
 109impl Principal {
 110    fn update_span(&self, span: &tracing::Span) {
 111        match &self {
 112            Principal::User(user) => {
 113                span.record("user_id", user.id.0);
 114                span.record("login", &user.github_login);
 115            }
 116            Principal::Impersonated { user, admin } => {
 117                span.record("user_id", user.id.0);
 118                span.record("login", &user.github_login);
 119                span.record("impersonator", &admin.github_login);
 120            }
 121        }
 122    }
 123}
 124
 125#[derive(Clone)]
 126struct Session {
 127    principal: Principal,
 128    connection_id: ConnectionId,
 129    db: Arc<tokio::sync::Mutex<DbHandle>>,
 130    peer: Arc<Peer>,
 131    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 132    app_state: Arc<AppState>,
 133    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 134    /// The GeoIP country code for the user.
 135    #[allow(unused)]
 136    geoip_country_code: Option<String>,
 137    system_id: Option<String>,
 138    _executor: Executor,
 139}
 140
 141impl Session {
 142    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 143        #[cfg(test)]
 144        tokio::task::yield_now().await;
 145        let guard = self.db.lock().await;
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        guard
 149    }
 150
 151    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.connection_pool.lock();
 155        ConnectionPoolGuard {
 156            guard,
 157            _not_send: PhantomData,
 158        }
 159    }
 160
 161    fn is_staff(&self) -> bool {
 162        match &self.principal {
 163            Principal::User(user) => user.admin,
 164            Principal::Impersonated { .. } => true,
 165        }
 166    }
 167
 168    pub async fn has_llm_subscription(
 169        &self,
 170        db: &MutexGuard<'_, DbHandle>,
 171    ) -> anyhow::Result<bool> {
 172        if self.is_staff() {
 173            return Ok(true);
 174        }
 175
 176        let user_id = self.user_id();
 177
 178        Ok(db.has_active_billing_subscription(user_id).await?)
 179    }
 180
 181    pub async fn current_plan(
 182        &self,
 183        _db: &MutexGuard<'_, DbHandle>,
 184    ) -> anyhow::Result<proto::Plan> {
 185        if self.is_staff() {
 186            Ok(proto::Plan::ZedPro)
 187        } else {
 188            Ok(proto::Plan::Free)
 189        }
 190    }
 191
 192    fn user_id(&self) -> UserId {
 193        match &self.principal {
 194            Principal::User(user) => user.id,
 195            Principal::Impersonated { user, .. } => user.id,
 196        }
 197    }
 198
 199    pub fn email(&self) -> Option<String> {
 200        match &self.principal {
 201            Principal::User(user) => user.email_address.clone(),
 202            Principal::Impersonated { user, .. } => user.email_address.clone(),
 203        }
 204    }
 205}
 206
 207impl Debug for Session {
 208    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 209        let mut result = f.debug_struct("Session");
 210        match &self.principal {
 211            Principal::User(user) => {
 212                result.field("user", &user.github_login);
 213            }
 214            Principal::Impersonated { user, admin } => {
 215                result.field("user", &user.github_login);
 216                result.field("impersonator", &admin.github_login);
 217            }
 218        }
 219        result.field("connection_id", &self.connection_id).finish()
 220    }
 221}
 222
 223struct DbHandle(Arc<Database>);
 224
 225impl Deref for DbHandle {
 226    type Target = Database;
 227
 228    fn deref(&self) -> &Self::Target {
 229        self.0.as_ref()
 230    }
 231}
 232
 233pub struct Server {
 234    id: parking_lot::Mutex<ServerId>,
 235    peer: Arc<Peer>,
 236    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 237    app_state: Arc<AppState>,
 238    handlers: HashMap<TypeId, MessageHandler>,
 239    teardown: watch::Sender<bool>,
 240}
 241
 242pub(crate) struct ConnectionPoolGuard<'a> {
 243    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 244    _not_send: PhantomData<Rc<()>>,
 245}
 246
 247#[derive(Serialize)]
 248pub struct ServerSnapshot<'a> {
 249    peer: &'a Peer,
 250    #[serde(serialize_with = "serialize_deref")]
 251    connection_pool: ConnectionPoolGuard<'a>,
 252}
 253
 254pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 255where
 256    S: Serializer,
 257    T: Deref<Target = U>,
 258    U: Serialize,
 259{
 260    Serialize::serialize(value.deref(), serializer)
 261}
 262
 263impl Server {
 264    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 265        let mut server = Self {
 266            id: parking_lot::Mutex::new(id),
 267            peer: Peer::new(id.0 as u32),
 268            app_state: app_state.clone(),
 269            connection_pool: Default::default(),
 270            handlers: Default::default(),
 271            teardown: watch::channel(false).0,
 272        };
 273
 274        server
 275            .add_request_handler(ping)
 276            .add_request_handler(create_room)
 277            .add_request_handler(join_room)
 278            .add_request_handler(rejoin_room)
 279            .add_request_handler(leave_room)
 280            .add_request_handler(set_room_participant_role)
 281            .add_request_handler(call)
 282            .add_request_handler(cancel_call)
 283            .add_message_handler(decline_call)
 284            .add_request_handler(update_participant_location)
 285            .add_request_handler(share_project)
 286            .add_message_handler(unshare_project)
 287            .add_request_handler(join_project)
 288            .add_message_handler(leave_project)
 289            .add_request_handler(update_project)
 290            .add_request_handler(update_worktree)
 291            .add_request_handler(update_repository)
 292            .add_request_handler(remove_repository)
 293            .add_message_handler(start_language_server)
 294            .add_message_handler(update_language_server)
 295            .add_message_handler(update_diagnostic_summary)
 296            .add_message_handler(update_worktree_settings)
 297            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 301            .add_request_handler(forward_find_search_candidates_request)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 315            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 316            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 317            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 318            .add_request_handler(
 319                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 320            )
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 323            )
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 329            .add_request_handler(
 330                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 331            )
 332            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 338            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 339            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 340            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 341            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 342            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 343            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 344            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 345            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 346            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 348            .add_request_handler(
 349                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 350            )
 351            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 352            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 353            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 354            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 355            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 356            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 357            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 358            .add_message_handler(create_buffer_for_peer)
 359            .add_request_handler(update_buffer)
 360            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 361            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 362            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 363            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 366            .add_request_handler(get_users)
 367            .add_request_handler(fuzzy_search_users)
 368            .add_request_handler(request_contact)
 369            .add_request_handler(remove_contact)
 370            .add_request_handler(respond_to_contact_request)
 371            .add_message_handler(subscribe_to_channels)
 372            .add_request_handler(create_channel)
 373            .add_request_handler(delete_channel)
 374            .add_request_handler(invite_channel_member)
 375            .add_request_handler(remove_channel_member)
 376            .add_request_handler(set_channel_member_role)
 377            .add_request_handler(set_channel_visibility)
 378            .add_request_handler(rename_channel)
 379            .add_request_handler(join_channel_buffer)
 380            .add_request_handler(leave_channel_buffer)
 381            .add_message_handler(update_channel_buffer)
 382            .add_request_handler(rejoin_channel_buffers)
 383            .add_request_handler(get_channel_members)
 384            .add_request_handler(respond_to_channel_invite)
 385            .add_request_handler(join_channel)
 386            .add_request_handler(join_channel_chat)
 387            .add_message_handler(leave_channel_chat)
 388            .add_request_handler(send_channel_message)
 389            .add_request_handler(remove_channel_message)
 390            .add_request_handler(update_channel_message)
 391            .add_request_handler(get_channel_messages)
 392            .add_request_handler(get_channel_messages_by_id)
 393            .add_request_handler(get_notifications)
 394            .add_request_handler(mark_notification_as_read)
 395            .add_request_handler(move_channel)
 396            .add_request_handler(follow)
 397            .add_message_handler(unfollow)
 398            .add_message_handler(update_followers)
 399            .add_request_handler(get_private_user_info)
 400            .add_request_handler(get_llm_api_token)
 401            .add_request_handler(accept_terms_of_service)
 402            .add_message_handler(acknowledge_channel_message)
 403            .add_message_handler(acknowledge_buffer_version)
 404            .add_request_handler(get_supermaven_api_key)
 405            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 406            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 407            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 408            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 409            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 410            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 411            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 412            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 413            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 414            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 416            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 417            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 418            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 419            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 420            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 421            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 422            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 423            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 424            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 425            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 426            .add_message_handler(update_context);
 427
 428        Arc::new(server)
 429    }
 430
    /// Schedules cleanup of resources left behind by stale servers.
    ///
    /// Returns `Ok(())` immediately. The work runs in a detached task on the app
    /// executor, inside a `"start server"` tracing span. After `CLEANUP_TIMEOUT` has
    /// elapsed, the task:
    ///
    /// 1. fetches stale room ids and stale channel-buffer ids for this environment
    ///    (NOTE(review): presumably those owned by servers other than `server_id` —
    ///    confirm against `stale_server_resource_ids`);
    /// 2. clears stale collaborators from each channel buffer and sends the refreshed
    ///    collaborator list to the buffer's remaining connections;
    /// 3. clears stale participants from each room, broadcasts the updated room (and
    ///    its channel, if any), sends `CallCanceled` to every connection of users whose
    ///    calls were canceled, pushes a refreshed contact entry for each affected user
    ///    to their accepted contacts' connections, and deletes the LiveKit room when a
    ///    LiveKit client is configured and the room has no participants left;
    /// 4. deletes the stale server records.
    ///
    /// Each database call and send goes through `trace_err` (which presumably logs)
    /// and a failure only skips that step; the rest of the cleanup still runs.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop stale collaborators, then tell every
                    // remaining connection about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false if clearing the room fails, which turns
                        // the notification and LiveKit steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both the removed participants and the callees of canceled
                            // calls need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Pool lock is scoped to this block so it is released before the
                        // database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, push their refreshed contact entry to the
                        // connections of each accepted contact. Skipped for a user if either
                        // lookup fails.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 576
 577    pub fn teardown(&self) {
 578        self.peer.teardown();
 579        self.connection_pool.lock().reset();
 580        let _ = self.teardown.send(true);
 581    }
 582
 583    #[cfg(test)]
 584    pub fn reset(&self, id: ServerId) {
 585        self.teardown();
 586        *self.id.lock() = id;
 587        self.peer.reset(id.0 as u32);
 588        let _ = self.teardown.send(false);
 589    }
 590
 591    #[cfg(test)]
 592    pub fn id(&self) -> ServerId {
 593        *self.id.lock()
 594    }
 595
 596    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 597    where
 598        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 599        Fut: 'static + Send + Future<Output = Result<()>>,
 600        M: EnvelopedMessage,
 601    {
 602        let prev_handler = self.handlers.insert(
 603            TypeId::of::<M>(),
 604            Box::new(move |envelope, session| {
 605                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 606                let received_at = envelope.received_at;
 607                tracing::info!("message received");
 608                let start_time = Instant::now();
 609                let future = (handler)(*envelope, session);
 610                async move {
 611                    let result = future.await;
 612                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 613                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 614                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 615                    let payload_type = M::NAME;
 616
 617                    match result {
 618                        Err(error) => {
 619                            tracing::error!(
 620                                ?error,
 621                                total_duration_ms,
 622                                processing_duration_ms,
 623                                queue_duration_ms,
 624                                payload_type,
 625                                "error handling message"
 626                            )
 627                        }
 628                        Ok(()) => tracing::info!(
 629                            total_duration_ms,
 630                            processing_duration_ms,
 631                            queue_duration_ms,
 632                            "finished handling message"
 633                        ),
 634                    }
 635                }
 636                .boxed()
 637            }),
 638        );
 639        if prev_handler.is_some() {
 640            panic!("registered a handler for the same message twice");
 641        }
 642        self
 643    }
 644
 645    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 646    where
 647        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 648        Fut: 'static + Send + Future<Output = Result<()>>,
 649        M: EnvelopedMessage,
 650    {
 651        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 652        self
 653    }
 654
 655    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 656    where
 657        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 658        Fut: Send + Future<Output = Result<()>>,
 659        M: RequestMessage,
 660    {
 661        let handler = Arc::new(handler);
 662        self.add_handler(move |envelope, session| {
 663            let receipt = envelope.receipt();
 664            let handler = handler.clone();
 665            async move {
 666                let peer = session.peer.clone();
 667                let responded = Arc::new(AtomicBool::default());
 668                let response = Response {
 669                    peer: peer.clone(),
 670                    responded: responded.clone(),
 671                    receipt,
 672                };
 673                match (handler)(envelope.payload, response, session).await {
 674                    Ok(()) => {
 675                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 676                            Ok(())
 677                        } else {
 678                            Err(anyhow!("handler did not send a response"))?
 679                        }
 680                    }
 681                    Err(error) => {
 682                        let proto_err = match &error {
 683                            Error::Internal(err) => err.to_proto(),
 684                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 685                        };
 686                        peer.respond_with_error(receipt, proto_err)?;
 687                        Err(error)
 688                    }
 689                }
 690            }
 691        })
 692    }
 693
    /// Builds the future that drives a single client connection for its whole
    /// lifetime.
    ///
    /// The returned future holds an `Arc` clone of the server rather than a
    /// borrow (`use<>`). When polled, it:
    /// 1. registers `connection` with the peer;
    /// 2. sends the initial handshake via `send_initial_client_update`;
    /// 3. dispatches incoming messages to the registered handlers.
    ///
    /// It stops when the I/O task finishes, the client closes the stream, or
    /// server teardown is signalled. After a normal exit from the message loop
    /// (I/O finished or stream closed) it calls `connection_lost` to sign the
    /// user out. On teardown it returns immediately.
    ///
    /// - `address`: remote address, recorded on the tracing spans.
    /// - `geoip_country_code` / `system_id`: optional client metadata stored on
    ///   the [`Session`]; the country code is also recorded on the span.
    /// - `send_connection_id`: if provided, receives the assigned `ConnectionId`
    ///   right after the `Hello` message is sent.
    /// - `executor`: used for peer timers, detached background handlers, and
    ///   sign-out.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the
        // peer assigns one, `geoip_country_code` just below, and (presumably)
        // the user fields by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Used both to refuse connections while tearing down and to abort the
        // message loop (and sign-out) once teardown starts.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this early return, and the one after
            // `send_initial_client_update` below, skip `connection_lost`. The
            // connection registered with `this.peer` above (and, if the handshake
            // got that far, its connection-pool entry) is therefore not
            // explicitly removed. Confirm this is intended.
            //
            // The HTTP client is only used to build the optional Supermaven
            // admin client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps this connection's in-flight handlers (foreground and background)
            // at 256. A permit is acquired *before* the next message is read and is
            // released only when that message's handler finishes, or immediately
            // when no handler is registered. At the cap, the loop stops pulling new
            // messages.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown (returns without signing
                // out), connection I/O, progress on in-flight foreground handlers,
                // and finally reading the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is stored;
                                // `.instrument(span)` re-enters it on every poll.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Abandon any still-running foreground handlers before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 841
 842    async fn send_initial_client_update(
 843        &self,
 844        connection_id: ConnectionId,
 845        principal: &Principal,
 846        zed_version: ZedVersion,
 847        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 848        session: &Session,
 849    ) -> Result<()> {
 850        self.peer.send(
 851            connection_id,
 852            proto::Hello {
 853                peer_id: Some(connection_id.into()),
 854            },
 855        )?;
 856        tracing::info!("sent hello message");
 857        if let Some(send_connection_id) = send_connection_id.take() {
 858            let _ = send_connection_id.send(connection_id);
 859        }
 860
 861        match principal {
 862            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 863                if !user.connected_once {
 864                    self.peer.send(connection_id, proto::ShowContacts {})?;
 865                    self.app_state
 866                        .db
 867                        .set_user_connected_once(user.id, true)
 868                        .await?;
 869                }
 870
 871                update_user_plan(user.id, session).await?;
 872
 873                let contacts = self.app_state.db.get_contacts(user.id).await?;
 874
 875                {
 876                    let mut pool = self.connection_pool.lock();
 877                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 878                    self.peer.send(
 879                        connection_id,
 880                        build_initial_contacts_update(contacts, &pool),
 881                    )?;
 882                }
 883
 884                if should_auto_subscribe_to_channels(zed_version) {
 885                    subscribe_user_to_channels(user.id, session).await?;
 886                }
 887
 888                if let Some(incoming_call) =
 889                    self.app_state.db.incoming_call_for_user(user.id).await?
 890                {
 891                    self.peer.send(connection_id, incoming_call)?;
 892                }
 893
 894                update_user_contacts(user.id, session).await?;
 895            }
 896        }
 897
 898        Ok(())
 899    }
 900
 901    pub async fn invite_code_redeemed(
 902        self: &Arc<Self>,
 903        inviter_id: UserId,
 904        invitee_id: UserId,
 905    ) -> Result<()> {
 906        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 907            if let Some(code) = &user.invite_code {
 908                let pool = self.connection_pool.lock();
 909                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 910                for connection_id in pool.user_connection_ids(inviter_id) {
 911                    self.peer.send(
 912                        connection_id,
 913                        proto::UpdateContacts {
 914                            contacts: vec![invitee_contact.clone()],
 915                            ..Default::default()
 916                        },
 917                    )?;
 918                    self.peer.send(
 919                        connection_id,
 920                        proto::UpdateInviteInfo {
 921                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 922                            count: user.invite_count as u32,
 923                        },
 924                    )?;
 925                }
 926            }
 927        }
 928        Ok(())
 929    }
 930
 931    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 932        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 933            if let Some(invite_code) = &user.invite_code {
 934                let pool = self.connection_pool.lock();
 935                for connection_id in pool.user_connection_ids(user_id) {
 936                    self.peer.send(
 937                        connection_id,
 938                        proto::UpdateInviteInfo {
 939                            url: format!(
 940                                "{}{}",
 941                                self.app_state.config.invite_link_prefix, invite_code
 942                            ),
 943                            count: user.invite_count as u32,
 944                        },
 945                    )?;
 946                }
 947            }
 948        }
 949        Ok(())
 950    }
 951
 952    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 953        let pool = self.connection_pool.lock();
 954        for connection_id in pool.user_connection_ids(user_id) {
 955            self.peer
 956                .send(connection_id, proto::RefreshLlmToken {})
 957                .trace_err();
 958        }
 959    }
 960
 961    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 962        ServerSnapshot {
 963            connection_pool: ConnectionPoolGuard {
 964                guard: self.connection_pool.lock(),
 965                _not_send: PhantomData,
 966            },
 967            peer: &self.peer,
 968        }
 969    }
 970}
 971
 972impl Deref for ConnectionPoolGuard<'_> {
 973    type Target = ConnectionPool;
 974
 975    fn deref(&self) -> &Self::Target {
 976        &self.guard
 977    }
 978}
 979
 980impl DerefMut for ConnectionPoolGuard<'_> {
 981    fn deref_mut(&mut self) -> &mut Self::Target {
 982        &mut self.guard
 983    }
 984}
 985
/// In test builds, checks the pool's invariants each time a guard is released,
/// so that any mutation that breaks them fails at the point of release. In
/// non-test builds the call is compiled out and dropping the guard just
/// releases the wrapped lock.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 992
 993fn broadcast<F>(
 994    sender_id: Option<ConnectionId>,
 995    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 996    mut f: F,
 997) where
 998    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 999{
1000    for receiver_id in receiver_ids {
1001        if Some(receiver_id) != sender_id {
1002            if let Err(error) = f(receiver_id) {
1003                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1004            }
1005        }
1006    }
1007}
1008
1009pub struct ProtocolVersion(u32);
1010
1011impl Header for ProtocolVersion {
1012    fn name() -> &'static HeaderName {
1013        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1014        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1015    }
1016
1017    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1018    where
1019        Self: Sized,
1020        I: Iterator<Item = &'i axum::http::HeaderValue>,
1021    {
1022        let version = values
1023            .next()
1024            .ok_or_else(axum::headers::Error::invalid)?
1025            .to_str()
1026            .map_err(|_| axum::headers::Error::invalid())?
1027            .parse()
1028            .map_err(|_| axum::headers::Error::invalid())?;
1029        Ok(Self(version))
1030    }
1031
1032    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1033        values.extend([self.0.to_string().parse().unwrap()]);
1034    }
1035}
1036
1037pub struct AppVersionHeader(SemanticVersion);
1038impl Header for AppVersionHeader {
1039    fn name() -> &'static HeaderName {
1040        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1042    }
1043
1044    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045    where
1046        Self: Sized,
1047        I: Iterator<Item = &'i axum::http::HeaderValue>,
1048    {
1049        let version = values
1050            .next()
1051            .ok_or_else(axum::headers::Error::invalid)?
1052            .to_str()
1053            .map_err(|_| axum::headers::Error::invalid())?
1054            .parse()
1055            .map_err(|_| axum::headers::Error::invalid())?;
1056        Ok(Self(version))
1057    }
1058
1059    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060        values.extend([self.0.to_string().parse().unwrap()]);
1061    }
1062}
1063
1064pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1065    Router::new()
1066        .route("/rpc", get(handle_websocket_request))
1067        .layer(
1068            ServiceBuilder::new()
1069                .layer(Extension(server.app_state.clone()))
1070                .layer(middleware::from_fn(auth::validate_header)),
1071        )
1072        .route("/metrics", get(handle_metrics))
1073        .layer(Extension(server))
1074}
1075
1076pub async fn handle_websocket_request(
1077    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1078    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1079    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1080    Extension(server): Extension<Arc<Server>>,
1081    Extension(principal): Extension<Principal>,
1082    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1083    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1084    ws: WebSocketUpgrade,
1085) -> axum::response::Response {
1086    if protocol_version != rpc::PROTOCOL_VERSION {
1087        return (
1088            StatusCode::UPGRADE_REQUIRED,
1089            "client must be upgraded".to_string(),
1090        )
1091            .into_response();
1092    }
1093
1094    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1095        return (
1096            StatusCode::UPGRADE_REQUIRED,
1097            "no version header found".to_string(),
1098        )
1099            .into_response();
1100    };
1101
1102    if !version.can_collaborate() {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "client must be upgraded".to_string(),
1106        )
1107            .into_response();
1108    }
1109
1110    let socket_address = socket_address.to_string();
1111    ws.on_upgrade(move |socket| {
1112        let socket = socket
1113            .map_ok(to_tungstenite_message)
1114            .err_into()
1115            .with(|message| async move { to_axum_message(message) });
1116        let connection = Connection::new(Box::pin(socket));
1117        async move {
1118            server
1119                .handle_connection(
1120                    connection,
1121                    socket_address,
1122                    principal,
1123                    version,
1124                    country_code_header.map(|header| header.to_string()),
1125                    system_id_header.map(|header| header.to_string()),
1126                    None,
1127                    Executor::Production,
1128                )
1129                .await;
1130        }
1131    })
1132}
1133
1134pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1135    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136    let connections_metric = CONNECTIONS_METRIC
1137        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1138
1139    let connections = server
1140        .connection_pool
1141        .lock()
1142        .connections()
1143        .filter(|connection| !connection.admin)
1144        .count();
1145    connections_metric.set(connections as _);
1146
1147    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1149        register_int_gauge!(
1150            "shared_projects",
1151            "number of open projects with one or more guests"
1152        )
1153        .unwrap()
1154    });
1155
1156    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1157    shared_projects_metric.set(shared_projects as _);
1158
1159    let encoder = prometheus::TextEncoder::new();
1160    let metric_families = prometheus::gather();
1161    let encoded_metrics = encoder
1162        .encode_to_string(&metric_families)
1163        .map_err(|err| anyhow!("{}", err))?;
1164    Ok(encoded_metrics)
1165}
1166
1167#[instrument(err, skip(executor))]
1168async fn connection_lost(
1169    session: Session,
1170    mut teardown: watch::Receiver<bool>,
1171    executor: Executor,
1172) -> Result<()> {
1173    session.peer.disconnect(session.connection_id);
1174    session
1175        .connection_pool()
1176        .await
1177        .remove_connection(session.connection_id)?;
1178
1179    session
1180        .db()
1181        .await
1182        .connection_lost(session.connection_id)
1183        .await
1184        .trace_err();
1185
1186    futures::select_biased! {
1187        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1188
1189            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1190            leave_room_for_session(&session, session.connection_id).await.trace_err();
1191            leave_channel_buffers_for_session(&session)
1192                .await
1193                .trace_err();
1194
1195            if !session
1196                .connection_pool()
1197                .await
1198                .is_user_online(session.user_id())
1199            {
1200                let db = session.db().await;
1201                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1202                    room_updated(&room, &session.peer);
1203                }
1204            }
1205
1206            update_user_contacts(session.user_id(), &session).await?;
1207        },
1208        _ = teardown.changed().fuse() => {}
1209    }
1210
1211    Ok(())
1212}
1213
1214/// Acknowledges a ping from a client, used to keep the connection alive.
1215async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1216    response.send(proto::Ack {})?;
1217    Ok(())
1218}
1219
1220/// Creates a new room for calling (outside of channels)
1221async fn create_room(
1222    _request: proto::CreateRoom,
1223    response: Response<proto::CreateRoom>,
1224    session: Session,
1225) -> Result<()> {
1226    let livekit_room = nanoid::nanoid!(30);
1227
1228    let live_kit_connection_info = util::maybe!(async {
1229        let live_kit = session.app_state.livekit_client.as_ref();
1230        let live_kit = live_kit?;
1231        let user_id = session.user_id().to_string();
1232
1233        let token = live_kit
1234            .room_token(&livekit_room, &user_id.to_string())
1235            .trace_err()?;
1236
1237        Some(proto::LiveKitConnectionInfo {
1238            server_url: live_kit.url().into(),
1239            token,
1240            can_publish: true,
1241        })
1242    })
1243    .await;
1244
1245    let room = session
1246        .db()
1247        .await
1248        .create_room(session.user_id(), session.connection_id, &livekit_room)
1249        .await?;
1250
1251    response.send(proto::CreateRoomResponse {
1252        room: Some(room.clone()),
1253        live_kit_connection_info,
1254    })?;
1255
1256    update_user_contacts(session.user_id(), &session).await?;
1257    Ok(())
1258}
1259
1260/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1261async fn join_room(
1262    request: proto::JoinRoom,
1263    response: Response<proto::JoinRoom>,
1264    session: Session,
1265) -> Result<()> {
1266    let room_id = RoomId::from_proto(request.id);
1267
1268    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1269
1270    if let Some(channel_id) = channel_id {
1271        return join_channel_internal(channel_id, Box::new(response), session).await;
1272    }
1273
1274    let joined_room = {
1275        let room = session
1276            .db()
1277            .await
1278            .join_room(room_id, session.user_id(), session.connection_id)
1279            .await?;
1280        room_updated(&room.room, &session.peer);
1281        room.into_inner()
1282    };
1283
1284    for connection_id in session
1285        .connection_pool()
1286        .await
1287        .user_connection_ids(session.user_id())
1288    {
1289        session
1290            .peer
1291            .send(
1292                connection_id,
1293                proto::CallCanceled {
1294                    room_id: room_id.to_proto(),
1295                },
1296            )
1297            .trace_err();
1298    }
1299
1300    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1301    {
1302        live_kit
1303            .room_token(
1304                &joined_room.room.livekit_room,
1305                &session.user_id().to_string(),
1306            )
1307            .trace_err()
1308            .map(|token| proto::LiveKitConnectionInfo {
1309                server_url: live_kit.url().into(),
1310                token,
1311                can_publish: true,
1312            })
1313    } else {
1314        None
1315    };
1316
1317    response.send(proto::JoinRoomResponse {
1318        room: Some(joined_room.room),
1319        channel_id: None,
1320        live_kit_connection_info,
1321    })?;
1322
1323    update_user_contacts(session.user_id(), &session).await?;
1324    Ok(())
1325}
1326
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room membership in the database under the current
/// `session.connection_id`, then:
/// - replies with the room, the re-shared projects (with their collaborators)
///   and the re-joined projects;
/// - broadcasts the updated room state;
/// - for each re-shared project (presumably ones the caller hosts), tells every
///   collaborator that the caller's peer id changed from `old_connection_id` to
///   the current connection, and forwards the project's worktree list to all
///   collaborators except the caller;
/// - streams the state of each re-joined project back to the caller via
///   `notify_rejoined_projects`;
/// - if the room belongs to a channel, broadcasts the channel update;
/// - finally updates the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the scope below so they outlive `rejoined_room`.
    // NOTE(review): `rejoined_room` (unwrapped with `into_inner()`) looks like a
    // database transaction guard. Confining it to this scope presumably releases
    // it before the connection pool is locked below. Confirm.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply with the restored room plus both project lists.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: collaborators that can't be reached are logged and skipped.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to every collaborator except
            // the caller itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Streams re-joined project state to the caller; this moves the
        // worktree and repository data out of `rejoined_projects`.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1418
1419fn notify_rejoined_projects(
1420    rejoined_projects: &mut Vec<RejoinedProject>,
1421    session: &Session,
1422) -> Result<()> {
1423    for project in rejoined_projects.iter() {
1424        for collaborator in &project.collaborators {
1425            session
1426                .peer
1427                .send(
1428                    collaborator.connection_id,
1429                    proto::UpdateProjectCollaborator {
1430                        project_id: project.id.to_proto(),
1431                        old_peer_id: Some(project.old_connection_id.into()),
1432                        new_peer_id: Some(session.connection_id.into()),
1433                    },
1434                )
1435                .trace_err();
1436        }
1437    }
1438
1439    for project in rejoined_projects {
1440        for worktree in mem::take(&mut project.worktrees) {
1441            // Stream this worktree's entries.
1442            let message = proto::UpdateWorktree {
1443                project_id: project.id.to_proto(),
1444                worktree_id: worktree.id,
1445                abs_path: worktree.abs_path.clone(),
1446                root_name: worktree.root_name,
1447                updated_entries: worktree.updated_entries,
1448                removed_entries: worktree.removed_entries,
1449                scan_id: worktree.scan_id,
1450                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1451                updated_repositories: worktree.updated_repositories,
1452                removed_repositories: worktree.removed_repositories,
1453            };
1454            for update in proto::split_worktree_update(message) {
1455                session.peer.send(session.connection_id, update)?;
1456            }
1457
1458            // Stream this worktree's diagnostics.
1459            for summary in worktree.diagnostic_summaries {
1460                session.peer.send(
1461                    session.connection_id,
1462                    proto::UpdateDiagnosticSummary {
1463                        project_id: project.id.to_proto(),
1464                        worktree_id: worktree.id,
1465                        summary: Some(summary),
1466                    },
1467                )?;
1468            }
1469
1470            for settings_file in worktree.settings_files {
1471                session.peer.send(
1472                    session.connection_id,
1473                    proto::UpdateWorktreeSettings {
1474                        project_id: project.id.to_proto(),
1475                        worktree_id: worktree.id,
1476                        path: settings_file.path,
1477                        content: Some(settings_file.content),
1478                        kind: Some(settings_file.kind.to_proto().into()),
1479                    },
1480                )?;
1481            }
1482        }
1483
1484        for repository in mem::take(&mut project.updated_repositories) {
1485            for update in split_repository_update(repository) {
1486                session.peer.send(session.connection_id, update)?;
1487            }
1488        }
1489
1490        for id in mem::take(&mut project.removed_repositories) {
1491            session.peer.send(
1492                session.connection_id,
1493                proto::RemoveRepository {
1494                    project_id: project.id.to_proto(),
1495                    id,
1496                },
1497            )?;
1498        }
1499    }
1500
1501    Ok(())
1502}
1503
1504/// leave room disconnects from the room.
1505async fn leave_room(
1506    _: proto::LeaveRoom,
1507    response: Response<proto::LeaveRoom>,
1508    session: Session,
1509) -> Result<()> {
1510    leave_room_for_session(&session, session.connection_id).await?;
1511    response.send(proto::Ack {})?;
1512    Ok(())
1513}
1514
1515/// Updates the permissions of someone else in the room.
1516async fn set_room_participant_role(
1517    request: proto::SetRoomParticipantRole,
1518    response: Response<proto::SetRoomParticipantRole>,
1519    session: Session,
1520) -> Result<()> {
1521    let user_id = UserId::from_proto(request.user_id);
1522    let role = ChannelRole::from(request.role());
1523
1524    let (livekit_room, can_publish) = {
1525        let room = session
1526            .db()
1527            .await
1528            .set_room_participant_role(
1529                session.user_id(),
1530                RoomId::from_proto(request.room_id),
1531                user_id,
1532                role,
1533            )
1534            .await?;
1535
1536        let livekit_room = room.livekit_room.clone();
1537        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1538        room_updated(&room, &session.peer);
1539        (livekit_room, can_publish)
1540    };
1541
1542    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1543        live_kit
1544            .update_participant(
1545                livekit_room.clone(),
1546                request.user_id.to_string(),
1547                livekit_api::proto::ParticipantPermission {
1548                    can_subscribe: true,
1549                    can_publish,
1550                    can_publish_data: can_publish,
1551                    hidden: false,
1552                    recorder: false,
1553                },
1554            )
1555            .await
1556            .trace_err();
1557    }
1558
1559    response.send(proto::Ack {})?;
1560    Ok(())
1561}
1562
1563/// Call someone else into the current room
1564async fn call(
1565    request: proto::Call,
1566    response: Response<proto::Call>,
1567    session: Session,
1568) -> Result<()> {
1569    let room_id = RoomId::from_proto(request.room_id);
1570    let calling_user_id = session.user_id();
1571    let calling_connection_id = session.connection_id;
1572    let called_user_id = UserId::from_proto(request.called_user_id);
1573    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1574    if !session
1575        .db()
1576        .await
1577        .has_contact(calling_user_id, called_user_id)
1578        .await?
1579    {
1580        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1581    }
1582
1583    let incoming_call = {
1584        let (room, incoming_call) = &mut *session
1585            .db()
1586            .await
1587            .call(
1588                room_id,
1589                calling_user_id,
1590                calling_connection_id,
1591                called_user_id,
1592                initial_project_id,
1593            )
1594            .await?;
1595        room_updated(room, &session.peer);
1596        mem::take(incoming_call)
1597    };
1598    update_user_contacts(called_user_id, &session).await?;
1599
1600    let mut calls = session
1601        .connection_pool()
1602        .await
1603        .user_connection_ids(called_user_id)
1604        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1605        .collect::<FuturesUnordered<_>>();
1606
1607    while let Some(call_response) = calls.next().await {
1608        match call_response.as_ref() {
1609            Ok(_) => {
1610                response.send(proto::Ack {})?;
1611                return Ok(());
1612            }
1613            Err(_) => {
1614                call_response.trace_err();
1615            }
1616        }
1617    }
1618
1619    {
1620        let room = session
1621            .db()
1622            .await
1623            .call_failed(room_id, called_user_id)
1624            .await?;
1625        room_updated(&room, &session.peer);
1626    }
1627    update_user_contacts(called_user_id, &session).await?;
1628
1629    Err(anyhow!("failed to ring user"))?
1630}
1631
1632/// Cancel an outgoing call.
1633async fn cancel_call(
1634    request: proto::CancelCall,
1635    response: Response<proto::CancelCall>,
1636    session: Session,
1637) -> Result<()> {
1638    let called_user_id = UserId::from_proto(request.called_user_id);
1639    let room_id = RoomId::from_proto(request.room_id);
1640    {
1641        let room = session
1642            .db()
1643            .await
1644            .cancel_call(room_id, session.connection_id, called_user_id)
1645            .await?;
1646        room_updated(&room, &session.peer);
1647    }
1648
1649    for connection_id in session
1650        .connection_pool()
1651        .await
1652        .user_connection_ids(called_user_id)
1653    {
1654        session
1655            .peer
1656            .send(
1657                connection_id,
1658                proto::CallCanceled {
1659                    room_id: room_id.to_proto(),
1660                },
1661            )
1662            .trace_err();
1663    }
1664    response.send(proto::Ack {})?;
1665
1666    update_user_contacts(called_user_id, &session).await?;
1667    Ok(())
1668}
1669
1670/// Decline an incoming call.
1671async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1672    let room_id = RoomId::from_proto(message.room_id);
1673    {
1674        let room = session
1675            .db()
1676            .await
1677            .decline_call(Some(room_id), session.user_id())
1678            .await?
1679            .ok_or_else(|| anyhow!("failed to decline call"))?;
1680        room_updated(&room, &session.peer);
1681    }
1682
1683    for connection_id in session
1684        .connection_pool()
1685        .await
1686        .user_connection_ids(session.user_id())
1687    {
1688        session
1689            .peer
1690            .send(
1691                connection_id,
1692                proto::CallCanceled {
1693                    room_id: room_id.to_proto(),
1694                },
1695            )
1696            .trace_err();
1697    }
1698    update_user_contacts(session.user_id(), &session).await?;
1699    Ok(())
1700}
1701
1702/// Updates other participants in the room with your current location.
1703async fn update_participant_location(
1704    request: proto::UpdateParticipantLocation,
1705    response: Response<proto::UpdateParticipantLocation>,
1706    session: Session,
1707) -> Result<()> {
1708    let room_id = RoomId::from_proto(request.room_id);
1709    let location = request
1710        .location
1711        .ok_or_else(|| anyhow!("invalid location"))?;
1712
1713    let db = session.db().await;
1714    let room = db
1715        .update_room_participant_location(room_id, session.connection_id, location)
1716        .await?;
1717
1718    room_updated(&room, &session.peer);
1719    response.send(proto::Ack {})?;
1720    Ok(())
1721}
1722
1723/// Share a project into the room.
1724async fn share_project(
1725    request: proto::ShareProject,
1726    response: Response<proto::ShareProject>,
1727    session: Session,
1728) -> Result<()> {
1729    let (project_id, room) = &*session
1730        .db()
1731        .await
1732        .share_project(
1733            RoomId::from_proto(request.room_id),
1734            session.connection_id,
1735            &request.worktrees,
1736            request.is_ssh_project,
1737        )
1738        .await?;
1739    response.send(proto::ShareProjectResponse {
1740        project_id: project_id.to_proto(),
1741    })?;
1742    room_updated(room, &session.peer);
1743
1744    Ok(())
1745}
1746
1747/// Unshare a project from the room.
1748async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1749    let project_id = ProjectId::from_proto(message.project_id);
1750    unshare_project_internal(project_id, session.connection_id, &session).await
1751}
1752
1753async fn unshare_project_internal(
1754    project_id: ProjectId,
1755    connection_id: ConnectionId,
1756    session: &Session,
1757) -> Result<()> {
1758    let delete = {
1759        let room_guard = session
1760            .db()
1761            .await
1762            .unshare_project(project_id, connection_id)
1763            .await?;
1764
1765        let (delete, room, guest_connection_ids) = &*room_guard;
1766
1767        let message = proto::UnshareProject {
1768            project_id: project_id.to_proto(),
1769        };
1770
1771        broadcast(
1772            Some(connection_id),
1773            guest_connection_ids.iter().copied(),
1774            |conn_id| session.peer.send(conn_id, message.clone()),
1775        );
1776        if let Some(room) = room {
1777            room_updated(room, &session.peer);
1778        }
1779
1780        *delete
1781    };
1782
1783    if delete {
1784        let db = session.db().await;
1785        db.delete_project(project_id).await?;
1786    }
1787
1788    Ok(())
1789}
1790
1791/// Join someone elses shared project.
1792async fn join_project(
1793    request: proto::JoinProject,
1794    response: Response<proto::JoinProject>,
1795    session: Session,
1796) -> Result<()> {
1797    let project_id = ProjectId::from_proto(request.project_id);
1798
1799    tracing::info!(%project_id, "join project");
1800
1801    let db = session.db().await;
1802    let (project, replica_id) = &mut *db
1803        .join_project(project_id, session.connection_id, session.user_id())
1804        .await?;
1805    drop(db);
1806    tracing::info!(%project_id, "join remote project");
1807    join_project_internal(response, session, project, replica_id)
1808}
1809
/// Destination for the `JoinProjectResponse` built by `join_project_internal`.
///
/// `join_project_internal` accepts any `impl JoinProjectInternalResponse`, so
/// it does not depend on one concrete response type. In this section the only
/// visible implementor is `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Delivers `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request through its ordinary RPC `Response`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified path so this delegates to `Response`'s own `send`
        // (inherent items win path resolution) rather than reading as a
        // recursive call to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1818
1819fn join_project_internal(
1820    response: impl JoinProjectInternalResponse,
1821    session: Session,
1822    project: &mut Project,
1823    replica_id: &ReplicaId,
1824) -> Result<()> {
1825    let collaborators = project
1826        .collaborators
1827        .iter()
1828        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829        .map(|collaborator| collaborator.to_proto())
1830        .collect::<Vec<_>>();
1831    let project_id = project.id;
1832    let guest_user_id = session.user_id();
1833
1834    let worktrees = project
1835        .worktrees
1836        .iter()
1837        .map(|(id, worktree)| proto::WorktreeMetadata {
1838            id: *id,
1839            root_name: worktree.root_name.clone(),
1840            visible: worktree.visible,
1841            abs_path: worktree.abs_path.clone(),
1842        })
1843        .collect::<Vec<_>>();
1844
1845    let add_project_collaborator = proto::AddProjectCollaborator {
1846        project_id: project_id.to_proto(),
1847        collaborator: Some(proto::Collaborator {
1848            peer_id: Some(session.connection_id.into()),
1849            replica_id: replica_id.0 as u32,
1850            user_id: guest_user_id.to_proto(),
1851            is_host: false,
1852        }),
1853    };
1854
1855    for collaborator in &collaborators {
1856        session
1857            .peer
1858            .send(
1859                collaborator.peer_id.unwrap().into(),
1860                add_project_collaborator.clone(),
1861            )
1862            .trace_err();
1863    }
1864
1865    // First, we send the metadata associated with each worktree.
1866    response.send(proto::JoinProjectResponse {
1867        project_id: project.id.0 as u64,
1868        worktrees: worktrees.clone(),
1869        replica_id: replica_id.0 as u32,
1870        collaborators: collaborators.clone(),
1871        language_servers: project.language_servers.clone(),
1872        role: project.role.into(),
1873    })?;
1874
1875    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1876        // Stream this worktree's entries.
1877        let message = proto::UpdateWorktree {
1878            project_id: project_id.to_proto(),
1879            worktree_id,
1880            abs_path: worktree.abs_path.clone(),
1881            root_name: worktree.root_name,
1882            updated_entries: worktree.entries,
1883            removed_entries: Default::default(),
1884            scan_id: worktree.scan_id,
1885            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1886            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1887            removed_repositories: Default::default(),
1888        };
1889        for update in proto::split_worktree_update(message) {
1890            session.peer.send(session.connection_id, update.clone())?;
1891        }
1892
1893        // Stream this worktree's diagnostics.
1894        for summary in worktree.diagnostic_summaries {
1895            session.peer.send(
1896                session.connection_id,
1897                proto::UpdateDiagnosticSummary {
1898                    project_id: project_id.to_proto(),
1899                    worktree_id: worktree.id,
1900                    summary: Some(summary),
1901                },
1902            )?;
1903        }
1904
1905        for settings_file in worktree.settings_files {
1906            session.peer.send(
1907                session.connection_id,
1908                proto::UpdateWorktreeSettings {
1909                    project_id: project_id.to_proto(),
1910                    worktree_id: worktree.id,
1911                    path: settings_file.path,
1912                    content: Some(settings_file.content),
1913                    kind: Some(settings_file.kind.to_proto() as i32),
1914                },
1915            )?;
1916        }
1917    }
1918
1919    for repository in mem::take(&mut project.repositories) {
1920        for update in split_repository_update(repository) {
1921            session.peer.send(session.connection_id, update)?;
1922        }
1923    }
1924
1925    for language_server in &project.language_servers {
1926        session.peer.send(
1927            session.connection_id,
1928            proto::UpdateLanguageServer {
1929                project_id: project_id.to_proto(),
1930                language_server_id: language_server.id,
1931                variant: Some(
1932                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1933                        proto::LspDiskBasedDiagnosticsUpdated {},
1934                    ),
1935                ),
1936            },
1937        )?;
1938    }
1939
1940    Ok(())
1941}
1942
1943/// Leave someone elses shared project.
1944async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1945    let sender_id = session.connection_id;
1946    let project_id = ProjectId::from_proto(request.project_id);
1947    let db = session.db().await;
1948
1949    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1950    tracing::info!(
1951        %project_id,
1952        "leave project"
1953    );
1954
1955    project_left(project, &session);
1956    if let Some(room) = room {
1957        room_updated(room, &session.peer);
1958    }
1959
1960    Ok(())
1961}
1962
1963/// Updates other participants with changes to the project
1964async fn update_project(
1965    request: proto::UpdateProject,
1966    response: Response<proto::UpdateProject>,
1967    session: Session,
1968) -> Result<()> {
1969    let project_id = ProjectId::from_proto(request.project_id);
1970    let (room, guest_connection_ids) = &*session
1971        .db()
1972        .await
1973        .update_project(project_id, session.connection_id, &request.worktrees)
1974        .await?;
1975    broadcast(
1976        Some(session.connection_id),
1977        guest_connection_ids.iter().copied(),
1978        |connection_id| {
1979            session
1980                .peer
1981                .forward_send(session.connection_id, connection_id, request.clone())
1982        },
1983    );
1984    if let Some(room) = room {
1985        room_updated(room, &session.peer);
1986    }
1987    response.send(proto::Ack {})?;
1988
1989    Ok(())
1990}
1991
1992/// Updates other participants with changes to the worktree
1993async fn update_worktree(
1994    request: proto::UpdateWorktree,
1995    response: Response<proto::UpdateWorktree>,
1996    session: Session,
1997) -> Result<()> {
1998    let guest_connection_ids = session
1999        .db()
2000        .await
2001        .update_worktree(&request, session.connection_id)
2002        .await?;
2003
2004    broadcast(
2005        Some(session.connection_id),
2006        guest_connection_ids.iter().copied(),
2007        |connection_id| {
2008            session
2009                .peer
2010                .forward_send(session.connection_id, connection_id, request.clone())
2011        },
2012    );
2013    response.send(proto::Ack {})?;
2014    Ok(())
2015}
2016
2017async fn update_repository(
2018    request: proto::UpdateRepository,
2019    response: Response<proto::UpdateRepository>,
2020    session: Session,
2021) -> Result<()> {
2022    let guest_connection_ids = session
2023        .db()
2024        .await
2025        .update_repository(&request, session.connection_id)
2026        .await?;
2027
2028    broadcast(
2029        Some(session.connection_id),
2030        guest_connection_ids.iter().copied(),
2031        |connection_id| {
2032            session
2033                .peer
2034                .forward_send(session.connection_id, connection_id, request.clone())
2035        },
2036    );
2037    response.send(proto::Ack {})?;
2038    Ok(())
2039}
2040
2041async fn remove_repository(
2042    request: proto::RemoveRepository,
2043    response: Response<proto::RemoveRepository>,
2044    session: Session,
2045) -> Result<()> {
2046    let guest_connection_ids = session
2047        .db()
2048        .await
2049        .remove_repository(&request, session.connection_id)
2050        .await?;
2051
2052    broadcast(
2053        Some(session.connection_id),
2054        guest_connection_ids.iter().copied(),
2055        |connection_id| {
2056            session
2057                .peer
2058                .forward_send(session.connection_id, connection_id, request.clone())
2059        },
2060    );
2061    response.send(proto::Ack {})?;
2062    Ok(())
2063}
2064
2065/// Updates other participants with changes to the diagnostics
2066async fn update_diagnostic_summary(
2067    message: proto::UpdateDiagnosticSummary,
2068    session: Session,
2069) -> Result<()> {
2070    let guest_connection_ids = session
2071        .db()
2072        .await
2073        .update_diagnostic_summary(&message, session.connection_id)
2074        .await?;
2075
2076    broadcast(
2077        Some(session.connection_id),
2078        guest_connection_ids.iter().copied(),
2079        |connection_id| {
2080            session
2081                .peer
2082                .forward_send(session.connection_id, connection_id, message.clone())
2083        },
2084    );
2085
2086    Ok(())
2087}
2088
2089/// Updates other participants with changes to the worktree settings
2090async fn update_worktree_settings(
2091    message: proto::UpdateWorktreeSettings,
2092    session: Session,
2093) -> Result<()> {
2094    let guest_connection_ids = session
2095        .db()
2096        .await
2097        .update_worktree_settings(&message, session.connection_id)
2098        .await?;
2099
2100    broadcast(
2101        Some(session.connection_id),
2102        guest_connection_ids.iter().copied(),
2103        |connection_id| {
2104            session
2105                .peer
2106                .forward_send(session.connection_id, connection_id, message.clone())
2107        },
2108    );
2109
2110    Ok(())
2111}
2112
2113/// Notify other participants that a language server has started.
2114async fn start_language_server(
2115    request: proto::StartLanguageServer,
2116    session: Session,
2117) -> Result<()> {
2118    let guest_connection_ids = session
2119        .db()
2120        .await
2121        .start_language_server(&request, session.connection_id)
2122        .await?;
2123
2124    broadcast(
2125        Some(session.connection_id),
2126        guest_connection_ids.iter().copied(),
2127        |connection_id| {
2128            session
2129                .peer
2130                .forward_send(session.connection_id, connection_id, request.clone())
2131        },
2132    );
2133    Ok(())
2134}
2135
2136/// Notify other participants that a language server has changed.
2137async fn update_language_server(
2138    request: proto::UpdateLanguageServer,
2139    session: Session,
2140) -> Result<()> {
2141    let project_id = ProjectId::from_proto(request.project_id);
2142    let project_connection_ids = session
2143        .db()
2144        .await
2145        .project_connection_ids(project_id, session.connection_id, true)
2146        .await?;
2147    broadcast(
2148        Some(session.connection_id),
2149        project_connection_ids.iter().copied(),
2150        |connection_id| {
2151            session
2152                .peer
2153                .forward_send(session.connection_id, connection_id, request.clone())
2154        },
2155    );
2156    Ok(())
2157}
2158
2159/// forward a project request to the host. These requests should be read only
2160/// as guests are allowed to send them.
2161async fn forward_read_only_project_request<T>(
2162    request: T,
2163    response: Response<T>,
2164    session: Session,
2165) -> Result<()>
2166where
2167    T: EntityMessage + RequestMessage,
2168{
2169    let project_id = ProjectId::from_proto(request.remote_entity_id());
2170    let host_connection_id = session
2171        .db()
2172        .await
2173        .host_for_read_only_project_request(project_id, session.connection_id)
2174        .await?;
2175    let payload = session
2176        .peer
2177        .forward_request(session.connection_id, host_connection_id, request)
2178        .await?;
2179    response.send(payload)?;
2180    Ok(())
2181}
2182
2183async fn forward_find_search_candidates_request(
2184    request: proto::FindSearchCandidates,
2185    response: Response<proto::FindSearchCandidates>,
2186    session: Session,
2187) -> Result<()> {
2188    let project_id = ProjectId::from_proto(request.remote_entity_id());
2189    let host_connection_id = session
2190        .db()
2191        .await
2192        .host_for_read_only_project_request(project_id, session.connection_id)
2193        .await?;
2194    let payload = session
2195        .peer
2196        .forward_request(session.connection_id, host_connection_id, request)
2197        .await?;
2198    response.send(payload)?;
2199    Ok(())
2200}
2201
2202/// forward a project request to the host. These requests are disallowed
2203/// for guests.
2204async fn forward_mutating_project_request<T>(
2205    request: T,
2206    response: Response<T>,
2207    session: Session,
2208) -> Result<()>
2209where
2210    T: EntityMessage + RequestMessage,
2211{
2212    let project_id = ProjectId::from_proto(request.remote_entity_id());
2213
2214    let host_connection_id = session
2215        .db()
2216        .await
2217        .host_for_mutating_project_request(project_id, session.connection_id)
2218        .await?;
2219    let payload = session
2220        .peer
2221        .forward_request(session.connection_id, host_connection_id, request)
2222        .await?;
2223    response.send(payload)?;
2224    Ok(())
2225}
2226
2227/// Notify other participants that a new buffer has been created
2228async fn create_buffer_for_peer(
2229    request: proto::CreateBufferForPeer,
2230    session: Session,
2231) -> Result<()> {
2232    session
2233        .db()
2234        .await
2235        .check_user_is_project_host(
2236            ProjectId::from_proto(request.project_id),
2237            session.connection_id,
2238        )
2239        .await?;
2240    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2241    session
2242        .peer
2243        .forward_send(session.connection_id, peer_id.into(), request)?;
2244    Ok(())
2245}
2246
2247/// Notify other participants that a buffer has been updated. This is
2248/// allowed for guests as long as the update is limited to selections.
2249async fn update_buffer(
2250    request: proto::UpdateBuffer,
2251    response: Response<proto::UpdateBuffer>,
2252    session: Session,
2253) -> Result<()> {
2254    let project_id = ProjectId::from_proto(request.project_id);
2255    let mut capability = Capability::ReadOnly;
2256
2257    for op in request.operations.iter() {
2258        match op.variant {
2259            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2260            Some(_) => capability = Capability::ReadWrite,
2261        }
2262    }
2263
2264    let host = {
2265        let guard = session
2266            .db()
2267            .await
2268            .connections_for_buffer_update(project_id, session.connection_id, capability)
2269            .await?;
2270
2271        let (host, guests) = &*guard;
2272
2273        broadcast(
2274            Some(session.connection_id),
2275            guests.clone(),
2276            |connection_id| {
2277                session
2278                    .peer
2279                    .forward_send(session.connection_id, connection_id, request.clone())
2280            },
2281        );
2282
2283        *host
2284    };
2285
2286    if host != session.connection_id {
2287        session
2288            .peer
2289            .forward_request(session.connection_id, host, request.clone())
2290            .await?;
2291    }
2292
2293    response.send(proto::Ack {})?;
2294    Ok(())
2295}
2296
2297async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2298    let project_id = ProjectId::from_proto(message.project_id);
2299
2300    let operation = message.operation.as_ref().context("invalid operation")?;
2301    let capability = match operation.variant.as_ref() {
2302        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2303            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2304                match buffer_op.variant {
2305                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2306                        Capability::ReadOnly
2307                    }
2308                    _ => Capability::ReadWrite,
2309                }
2310            } else {
2311                Capability::ReadWrite
2312            }
2313        }
2314        Some(_) => Capability::ReadWrite,
2315        None => Capability::ReadOnly,
2316    };
2317
2318    let guard = session
2319        .db()
2320        .await
2321        .connections_for_buffer_update(project_id, session.connection_id, capability)
2322        .await?;
2323
2324    let (host, guests) = &*guard;
2325
2326    broadcast(
2327        Some(session.connection_id),
2328        guests.iter().chain([host]).copied(),
2329        |connection_id| {
2330            session
2331                .peer
2332                .forward_send(session.connection_id, connection_id, message.clone())
2333        },
2334    );
2335
2336    Ok(())
2337}
2338
2339/// Notify other participants that a project has been updated.
2340async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2341    request: T,
2342    session: Session,
2343) -> Result<()> {
2344    let project_id = ProjectId::from_proto(request.remote_entity_id());
2345    let project_connection_ids = session
2346        .db()
2347        .await
2348        .project_connection_ids(project_id, session.connection_id, false)
2349        .await?;
2350
2351    broadcast(
2352        Some(session.connection_id),
2353        project_connection_ids.iter().copied(),
2354        |connection_id| {
2355            session
2356                .peer
2357                .forward_send(session.connection_id, connection_id, request.clone())
2358        },
2359    );
2360    Ok(())
2361}
2362
2363/// Start following another user in a call.
2364async fn follow(
2365    request: proto::Follow,
2366    response: Response<proto::Follow>,
2367    session: Session,
2368) -> Result<()> {
2369    let room_id = RoomId::from_proto(request.room_id);
2370    let project_id = request.project_id.map(ProjectId::from_proto);
2371    let leader_id = request
2372        .leader_id
2373        .ok_or_else(|| anyhow!("invalid leader id"))?
2374        .into();
2375    let follower_id = session.connection_id;
2376
2377    session
2378        .db()
2379        .await
2380        .check_room_participants(room_id, leader_id, session.connection_id)
2381        .await?;
2382
2383    let response_payload = session
2384        .peer
2385        .forward_request(session.connection_id, leader_id, request)
2386        .await?;
2387    response.send(response_payload)?;
2388
2389    if let Some(project_id) = project_id {
2390        let room = session
2391            .db()
2392            .await
2393            .follow(room_id, project_id, leader_id, follower_id)
2394            .await?;
2395        room_updated(&room, &session.peer);
2396    }
2397
2398    Ok(())
2399}
2400
2401/// Stop following another user in a call.
2402async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2403    let room_id = RoomId::from_proto(request.room_id);
2404    let project_id = request.project_id.map(ProjectId::from_proto);
2405    let leader_id = request
2406        .leader_id
2407        .ok_or_else(|| anyhow!("invalid leader id"))?
2408        .into();
2409    let follower_id = session.connection_id;
2410
2411    session
2412        .db()
2413        .await
2414        .check_room_participants(room_id, leader_id, session.connection_id)
2415        .await?;
2416
2417    session
2418        .peer
2419        .forward_send(session.connection_id, leader_id, request)?;
2420
2421    if let Some(project_id) = project_id {
2422        let room = session
2423            .db()
2424            .await
2425            .unfollow(room_id, project_id, leader_id, follower_id)
2426            .await?;
2427        room_updated(&room, &session.peer);
2428    }
2429
2430    Ok(())
2431}
2432
2433/// Notify everyone following you of your current location.
2434async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2435    let room_id = RoomId::from_proto(request.room_id);
2436    let database = session.db.lock().await;
2437
2438    let connection_ids = if let Some(project_id) = request.project_id {
2439        let project_id = ProjectId::from_proto(project_id);
2440        database
2441            .project_connection_ids(project_id, session.connection_id, true)
2442            .await?
2443    } else {
2444        database
2445            .room_connection_ids(room_id, session.connection_id)
2446            .await?
2447    };
2448
2449    // For now, don't send view update messages back to that view's current leader.
2450    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2451        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2452        _ => None,
2453    });
2454
2455    for connection_id in connection_ids.iter().cloned() {
2456        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2457            session
2458                .peer
2459                .forward_send(session.connection_id, connection_id, request.clone())?;
2460        }
2461    }
2462    Ok(())
2463}
2464
2465/// Get public data about users.
2466async fn get_users(
2467    request: proto::GetUsers,
2468    response: Response<proto::GetUsers>,
2469    session: Session,
2470) -> Result<()> {
2471    let user_ids = request
2472        .user_ids
2473        .into_iter()
2474        .map(UserId::from_proto)
2475        .collect();
2476    let users = session
2477        .db()
2478        .await
2479        .get_users_by_ids(user_ids)
2480        .await?
2481        .into_iter()
2482        .map(|user| proto::User {
2483            id: user.id.to_proto(),
2484            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2485            github_login: user.github_login,
2486            email: user.email_address,
2487            name: user.name,
2488        })
2489        .collect();
2490    response.send(proto::UsersResponse { users })?;
2491    Ok(())
2492}
2493
2494/// Search for users (to invite) buy Github login
2495async fn fuzzy_search_users(
2496    request: proto::FuzzySearchUsers,
2497    response: Response<proto::FuzzySearchUsers>,
2498    session: Session,
2499) -> Result<()> {
2500    let query = request.query;
2501    let users = match query.len() {
2502        0 => vec![],
2503        1 | 2 => session
2504            .db()
2505            .await
2506            .get_user_by_github_login(&query)
2507            .await?
2508            .into_iter()
2509            .collect(),
2510        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2511    };
2512    let users = users
2513        .into_iter()
2514        .filter(|user| user.id != session.user_id())
2515        .map(|user| proto::User {
2516            id: user.id.to_proto(),
2517            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2518            github_login: user.github_login,
2519            name: user.name,
2520            email: user.email_address,
2521        })
2522        .collect();
2523    response.send(proto::UsersResponse { users })?;
2524    Ok(())
2525}
2526
2527/// Send a contact request to another user.
2528async fn request_contact(
2529    request: proto::RequestContact,
2530    response: Response<proto::RequestContact>,
2531    session: Session,
2532) -> Result<()> {
2533    let requester_id = session.user_id();
2534    let responder_id = UserId::from_proto(request.responder_id);
2535    if requester_id == responder_id {
2536        return Err(anyhow!("cannot add yourself as a contact"))?;
2537    }
2538
2539    let notifications = session
2540        .db()
2541        .await
2542        .send_contact_request(requester_id, responder_id)
2543        .await?;
2544
2545    // Update outgoing contact requests of requester
2546    let mut update = proto::UpdateContacts::default();
2547    update.outgoing_requests.push(responder_id.to_proto());
2548    for connection_id in session
2549        .connection_pool()
2550        .await
2551        .user_connection_ids(requester_id)
2552    {
2553        session.peer.send(connection_id, update.clone())?;
2554    }
2555
2556    // Update incoming contact requests of responder
2557    let mut update = proto::UpdateContacts::default();
2558    update
2559        .incoming_requests
2560        .push(proto::IncomingContactRequest {
2561            requester_id: requester_id.to_proto(),
2562        });
2563    let connection_pool = session.connection_pool().await;
2564    for connection_id in connection_pool.user_connection_ids(responder_id) {
2565        session.peer.send(connection_id, update.clone())?;
2566    }
2567
2568    send_notifications(&connection_pool, &session.peer, notifications);
2569
2570    response.send(proto::Ack {})?;
2571    Ok(())
2572}
2573
2574/// Accept or decline a contact request
2575async fn respond_to_contact_request(
2576    request: proto::RespondToContactRequest,
2577    response: Response<proto::RespondToContactRequest>,
2578    session: Session,
2579) -> Result<()> {
2580    let responder_id = session.user_id();
2581    let requester_id = UserId::from_proto(request.requester_id);
2582    let db = session.db().await;
2583    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2584        db.dismiss_contact_notification(responder_id, requester_id)
2585            .await?;
2586    } else {
2587        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2588
2589        let notifications = db
2590            .respond_to_contact_request(responder_id, requester_id, accept)
2591            .await?;
2592        let requester_busy = db.is_user_busy(requester_id).await?;
2593        let responder_busy = db.is_user_busy(responder_id).await?;
2594
2595        let pool = session.connection_pool().await;
2596        // Update responder with new contact
2597        let mut update = proto::UpdateContacts::default();
2598        if accept {
2599            update
2600                .contacts
2601                .push(contact_for_user(requester_id, requester_busy, &pool));
2602        }
2603        update
2604            .remove_incoming_requests
2605            .push(requester_id.to_proto());
2606        for connection_id in pool.user_connection_ids(responder_id) {
2607            session.peer.send(connection_id, update.clone())?;
2608        }
2609
2610        // Update requester with new contact
2611        let mut update = proto::UpdateContacts::default();
2612        if accept {
2613            update
2614                .contacts
2615                .push(contact_for_user(responder_id, responder_busy, &pool));
2616        }
2617        update
2618            .remove_outgoing_requests
2619            .push(responder_id.to_proto());
2620
2621        for connection_id in pool.user_connection_ids(requester_id) {
2622            session.peer.send(connection_id, update.clone())?;
2623        }
2624
2625        send_notifications(&pool, &session.peer, notifications);
2626    }
2627
2628    response.send(proto::Ack {})?;
2629    Ok(())
2630}
2631
2632/// Remove a contact.
2633async fn remove_contact(
2634    request: proto::RemoveContact,
2635    response: Response<proto::RemoveContact>,
2636    session: Session,
2637) -> Result<()> {
2638    let requester_id = session.user_id();
2639    let responder_id = UserId::from_proto(request.user_id);
2640    let db = session.db().await;
2641    let (contact_accepted, deleted_notification_id) =
2642        db.remove_contact(requester_id, responder_id).await?;
2643
2644    let pool = session.connection_pool().await;
2645    // Update outgoing contact requests of requester
2646    let mut update = proto::UpdateContacts::default();
2647    if contact_accepted {
2648        update.remove_contacts.push(responder_id.to_proto());
2649    } else {
2650        update
2651            .remove_outgoing_requests
2652            .push(responder_id.to_proto());
2653    }
2654    for connection_id in pool.user_connection_ids(requester_id) {
2655        session.peer.send(connection_id, update.clone())?;
2656    }
2657
2658    // Update incoming contact requests of responder
2659    let mut update = proto::UpdateContacts::default();
2660    if contact_accepted {
2661        update.remove_contacts.push(requester_id.to_proto());
2662    } else {
2663        update
2664            .remove_incoming_requests
2665            .push(requester_id.to_proto());
2666    }
2667    for connection_id in pool.user_connection_ids(responder_id) {
2668        session.peer.send(connection_id, update.clone())?;
2669        if let Some(notification_id) = deleted_notification_id {
2670            session.peer.send(
2671                connection_id,
2672                proto::DeleteNotification {
2673                    notification_id: notification_id.to_proto(),
2674                },
2675            )?;
2676        }
2677    }
2678
2679    response.send(proto::Ack {})?;
2680    Ok(())
2681}
2682
2683fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2684    version.0.minor() < 139
2685}
2686
2687async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2688    let plan = session.current_plan(&session.db().await).await?;
2689
2690    session
2691        .peer
2692        .send(
2693            session.connection_id,
2694            proto::UpdateUserPlan { plan: plan.into() },
2695        )
2696        .trace_err();
2697
2698    Ok(())
2699}
2700
2701async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2702    subscribe_user_to_channels(session.user_id(), &session).await?;
2703    Ok(())
2704}
2705
2706async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2707    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2708    let mut pool = session.connection_pool().await;
2709    for membership in &channels_for_user.channel_memberships {
2710        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2711    }
2712    session.peer.send(
2713        session.connection_id,
2714        build_update_user_channels(&channels_for_user),
2715    )?;
2716    session.peer.send(
2717        session.connection_id,
2718        build_channels_update(channels_for_user),
2719    )?;
2720    Ok(())
2721}
2722
2723/// Creates a new channel.
2724async fn create_channel(
2725    request: proto::CreateChannel,
2726    response: Response<proto::CreateChannel>,
2727    session: Session,
2728) -> Result<()> {
2729    let db = session.db().await;
2730
2731    let parent_id = request.parent_id.map(ChannelId::from_proto);
2732    let (channel, membership) = db
2733        .create_channel(&request.name, parent_id, session.user_id())
2734        .await?;
2735
2736    let root_id = channel.root_id();
2737    let channel = Channel::from_model(channel);
2738
2739    response.send(proto::CreateChannelResponse {
2740        channel: Some(channel.to_proto()),
2741        parent_id: request.parent_id,
2742    })?;
2743
2744    let mut connection_pool = session.connection_pool().await;
2745    if let Some(membership) = membership {
2746        connection_pool.subscribe_to_channel(
2747            membership.user_id,
2748            membership.channel_id,
2749            membership.role,
2750        );
2751        let update = proto::UpdateUserChannels {
2752            channel_memberships: vec![proto::ChannelMembership {
2753                channel_id: membership.channel_id.to_proto(),
2754                role: membership.role.into(),
2755            }],
2756            ..Default::default()
2757        };
2758        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2759            session.peer.send(connection_id, update.clone())?;
2760        }
2761    }
2762
2763    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2764        if !role.can_see_channel(channel.visibility) {
2765            continue;
2766        }
2767
2768        let update = proto::UpdateChannels {
2769            channels: vec![channel.to_proto()],
2770            ..Default::default()
2771        };
2772        session.peer.send(connection_id, update.clone())?;
2773    }
2774
2775    Ok(())
2776}
2777
2778/// Delete a channel
2779async fn delete_channel(
2780    request: proto::DeleteChannel,
2781    response: Response<proto::DeleteChannel>,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785
2786    let channel_id = request.channel_id;
2787    let (root_channel, removed_channels) = db
2788        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2789        .await?;
2790    response.send(proto::Ack {})?;
2791
2792    // Notify members of removed channels
2793    let mut update = proto::UpdateChannels::default();
2794    update
2795        .delete_channels
2796        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2797
2798    let connection_pool = session.connection_pool().await;
2799    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2800        session.peer.send(connection_id, update.clone())?;
2801    }
2802
2803    Ok(())
2804}
2805
2806/// Invite someone to join a channel.
2807async fn invite_channel_member(
2808    request: proto::InviteChannelMember,
2809    response: Response<proto::InviteChannelMember>,
2810    session: Session,
2811) -> Result<()> {
2812    let db = session.db().await;
2813    let channel_id = ChannelId::from_proto(request.channel_id);
2814    let invitee_id = UserId::from_proto(request.user_id);
2815    let InviteMemberResult {
2816        channel,
2817        notifications,
2818    } = db
2819        .invite_channel_member(
2820            channel_id,
2821            invitee_id,
2822            session.user_id(),
2823            request.role().into(),
2824        )
2825        .await?;
2826
2827    let update = proto::UpdateChannels {
2828        channel_invitations: vec![channel.to_proto()],
2829        ..Default::default()
2830    };
2831
2832    let connection_pool = session.connection_pool().await;
2833    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2834        session.peer.send(connection_id, update.clone())?;
2835    }
2836
2837    send_notifications(&connection_pool, &session.peer, notifications);
2838
2839    response.send(proto::Ack {})?;
2840    Ok(())
2841}
2842
2843/// remove someone from a channel
2844async fn remove_channel_member(
2845    request: proto::RemoveChannelMember,
2846    response: Response<proto::RemoveChannelMember>,
2847    session: Session,
2848) -> Result<()> {
2849    let db = session.db().await;
2850    let channel_id = ChannelId::from_proto(request.channel_id);
2851    let member_id = UserId::from_proto(request.user_id);
2852
2853    let RemoveChannelMemberResult {
2854        membership_update,
2855        notification_id,
2856    } = db
2857        .remove_channel_member(channel_id, member_id, session.user_id())
2858        .await?;
2859
2860    let mut connection_pool = session.connection_pool().await;
2861    notify_membership_updated(
2862        &mut connection_pool,
2863        membership_update,
2864        member_id,
2865        &session.peer,
2866    );
2867    for connection_id in connection_pool.user_connection_ids(member_id) {
2868        if let Some(notification_id) = notification_id {
2869            session
2870                .peer
2871                .send(
2872                    connection_id,
2873                    proto::DeleteNotification {
2874                        notification_id: notification_id.to_proto(),
2875                    },
2876                )
2877                .trace_err();
2878        }
2879    }
2880
2881    response.send(proto::Ack {})?;
2882    Ok(())
2883}
2884
2885/// Toggle the channel between public and private.
2886/// Care is taken to maintain the invariant that public channels only descend from public channels,
2887/// (though members-only channels can appear at any point in the hierarchy).
2888async fn set_channel_visibility(
2889    request: proto::SetChannelVisibility,
2890    response: Response<proto::SetChannelVisibility>,
2891    session: Session,
2892) -> Result<()> {
2893    let db = session.db().await;
2894    let channel_id = ChannelId::from_proto(request.channel_id);
2895    let visibility = request.visibility().into();
2896
2897    let channel_model = db
2898        .set_channel_visibility(channel_id, visibility, session.user_id())
2899        .await?;
2900    let root_id = channel_model.root_id();
2901    let channel = Channel::from_model(channel_model);
2902
2903    let mut connection_pool = session.connection_pool().await;
2904    for (user_id, role) in connection_pool
2905        .channel_user_ids(root_id)
2906        .collect::<Vec<_>>()
2907        .into_iter()
2908    {
2909        let update = if role.can_see_channel(channel.visibility) {
2910            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2911            proto::UpdateChannels {
2912                channels: vec![channel.to_proto()],
2913                ..Default::default()
2914            }
2915        } else {
2916            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2917            proto::UpdateChannels {
2918                delete_channels: vec![channel.id.to_proto()],
2919                ..Default::default()
2920            }
2921        };
2922
2923        for connection_id in connection_pool.user_connection_ids(user_id) {
2924            session.peer.send(connection_id, update.clone())?;
2925        }
2926    }
2927
2928    response.send(proto::Ack {})?;
2929    Ok(())
2930}
2931
2932/// Alter the role for a user in the channel.
2933async fn set_channel_member_role(
2934    request: proto::SetChannelMemberRole,
2935    response: Response<proto::SetChannelMemberRole>,
2936    session: Session,
2937) -> Result<()> {
2938    let db = session.db().await;
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940    let member_id = UserId::from_proto(request.user_id);
2941    let result = db
2942        .set_channel_member_role(
2943            channel_id,
2944            session.user_id(),
2945            member_id,
2946            request.role().into(),
2947        )
2948        .await?;
2949
2950    match result {
2951        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2952            let mut connection_pool = session.connection_pool().await;
2953            notify_membership_updated(
2954                &mut connection_pool,
2955                membership_update,
2956                member_id,
2957                &session.peer,
2958            )
2959        }
2960        db::SetMemberRoleResult::InviteUpdated(channel) => {
2961            let update = proto::UpdateChannels {
2962                channel_invitations: vec![channel.to_proto()],
2963                ..Default::default()
2964            };
2965
2966            for connection_id in session
2967                .connection_pool()
2968                .await
2969                .user_connection_ids(member_id)
2970            {
2971                session.peer.send(connection_id, update.clone())?;
2972            }
2973        }
2974    }
2975
2976    response.send(proto::Ack {})?;
2977    Ok(())
2978}
2979
2980/// Change the name of a channel
2981async fn rename_channel(
2982    request: proto::RenameChannel,
2983    response: Response<proto::RenameChannel>,
2984    session: Session,
2985) -> Result<()> {
2986    let db = session.db().await;
2987    let channel_id = ChannelId::from_proto(request.channel_id);
2988    let channel_model = db
2989        .rename_channel(channel_id, session.user_id(), &request.name)
2990        .await?;
2991    let root_id = channel_model.root_id();
2992    let channel = Channel::from_model(channel_model);
2993
2994    response.send(proto::RenameChannelResponse {
2995        channel: Some(channel.to_proto()),
2996    })?;
2997
2998    let connection_pool = session.connection_pool().await;
2999    let update = proto::UpdateChannels {
3000        channels: vec![channel.to_proto()],
3001        ..Default::default()
3002    };
3003    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3004        if role.can_see_channel(channel.visibility) {
3005            session.peer.send(connection_id, update.clone())?;
3006        }
3007    }
3008
3009    Ok(())
3010}
3011
3012/// Move a channel to a new parent.
3013async fn move_channel(
3014    request: proto::MoveChannel,
3015    response: Response<proto::MoveChannel>,
3016    session: Session,
3017) -> Result<()> {
3018    let channel_id = ChannelId::from_proto(request.channel_id);
3019    let to = ChannelId::from_proto(request.to);
3020
3021    let (root_id, channels) = session
3022        .db()
3023        .await
3024        .move_channel(channel_id, to, session.user_id())
3025        .await?;
3026
3027    let connection_pool = session.connection_pool().await;
3028    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3029        let channels = channels
3030            .iter()
3031            .filter_map(|channel| {
3032                if role.can_see_channel(channel.visibility) {
3033                    Some(channel.to_proto())
3034                } else {
3035                    None
3036                }
3037            })
3038            .collect::<Vec<_>>();
3039        if channels.is_empty() {
3040            continue;
3041        }
3042
3043        let update = proto::UpdateChannels {
3044            channels,
3045            ..Default::default()
3046        };
3047
3048        session.peer.send(connection_id, update.clone())?;
3049    }
3050
3051    response.send(Ack {})?;
3052    Ok(())
3053}
3054
3055/// Get the list of channel members
3056async fn get_channel_members(
3057    request: proto::GetChannelMembers,
3058    response: Response<proto::GetChannelMembers>,
3059    session: Session,
3060) -> Result<()> {
3061    let db = session.db().await;
3062    let channel_id = ChannelId::from_proto(request.channel_id);
3063    let limit = if request.limit == 0 {
3064        u16::MAX as u64
3065    } else {
3066        request.limit
3067    };
3068    let (members, users) = db
3069        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3070        .await?;
3071    response.send(proto::GetChannelMembersResponse { members, users })?;
3072    Ok(())
3073}
3074
3075/// Accept or decline a channel invitation.
3076async fn respond_to_channel_invite(
3077    request: proto::RespondToChannelInvite,
3078    response: Response<proto::RespondToChannelInvite>,
3079    session: Session,
3080) -> Result<()> {
3081    let db = session.db().await;
3082    let channel_id = ChannelId::from_proto(request.channel_id);
3083    let RespondToChannelInvite {
3084        membership_update,
3085        notifications,
3086    } = db
3087        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3088        .await?;
3089
3090    let mut connection_pool = session.connection_pool().await;
3091    if let Some(membership_update) = membership_update {
3092        notify_membership_updated(
3093            &mut connection_pool,
3094            membership_update,
3095            session.user_id(),
3096            &session.peer,
3097        );
3098    } else {
3099        let update = proto::UpdateChannels {
3100            remove_channel_invitations: vec![channel_id.to_proto()],
3101            ..Default::default()
3102        };
3103
3104        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3105            session.peer.send(connection_id, update.clone())?;
3106        }
3107    };
3108
3109    send_notifications(&connection_pool, &session.peer, notifications);
3110
3111    response.send(proto::Ack {})?;
3112
3113    Ok(())
3114}
3115
3116/// Join the channels' room
3117async fn join_channel(
3118    request: proto::JoinChannel,
3119    response: Response<proto::JoinChannel>,
3120    session: Session,
3121) -> Result<()> {
3122    let channel_id = ChannelId::from_proto(request.channel_id);
3123    join_channel_internal(channel_id, Box::new(response), session).await
3124}
3125
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` responses so that
/// `join_channel_internal` can reply to either kind of request. (Presumably
/// the `JoinRoom` handler elsewhere in this file reuses it; verify.)
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully-qualified `Response::<…>::send` path below is meant to reach
// `Response`'s own `send` rather than this trait's identically named method.
// NOTE(review): assumes `Response` has an inherent `send` (inherent items win
// path resolution); otherwise this would recurse — confirm.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3139
/// Joins the current user to `channel_id`'s room and replies via `response`.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id, and optional LiveKit connection
///    info.
/// 4. Propagate any membership change and broadcast the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// Returns an error if any database step fails, if the reply cannot be sent,
/// or if the joined room comes back without a channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are
    // dropped when it ends, before the pool is re-acquired for
    // `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. Presumably `leave_room_for_session` locks the
            // database itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and `can_publish = false`; every other role
        // gets a room token and `can_publish = true`. If no LiveKit client is
        // configured, or token creation fails (`trace_err()?` yields `None`),
        // the reply carries no LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to always yield a channel; treat its absence
    // as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3235
3236/// Start editing the channel notes
3237async fn join_channel_buffer(
3238    request: proto::JoinChannelBuffer,
3239    response: Response<proto::JoinChannelBuffer>,
3240    session: Session,
3241) -> Result<()> {
3242    let db = session.db().await;
3243    let channel_id = ChannelId::from_proto(request.channel_id);
3244
3245    let open_response = db
3246        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3247        .await?;
3248
3249    let collaborators = open_response.collaborators.clone();
3250    response.send(open_response)?;
3251
3252    let update = UpdateChannelBufferCollaborators {
3253        channel_id: channel_id.to_proto(),
3254        collaborators: collaborators.clone(),
3255    };
3256    channel_buffer_updated(
3257        session.connection_id,
3258        collaborators
3259            .iter()
3260            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3261        &update,
3262        &session.peer,
3263    );
3264
3265    Ok(())
3266}
3267
3268/// Edit the channel notes
3269async fn update_channel_buffer(
3270    request: proto::UpdateChannelBuffer,
3271    session: Session,
3272) -> Result<()> {
3273    let db = session.db().await;
3274    let channel_id = ChannelId::from_proto(request.channel_id);
3275
3276    let (collaborators, epoch, version) = db
3277        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3278        .await?;
3279
3280    channel_buffer_updated(
3281        session.connection_id,
3282        collaborators.clone(),
3283        &proto::UpdateChannelBuffer {
3284            channel_id: channel_id.to_proto(),
3285            operations: request.operations,
3286        },
3287        &session.peer,
3288    );
3289
3290    let pool = &*session.connection_pool().await;
3291
3292    let non_collaborators =
3293        pool.channel_connection_ids(channel_id)
3294            .filter_map(|(connection_id, _)| {
3295                if collaborators.contains(&connection_id) {
3296                    None
3297                } else {
3298                    Some(connection_id)
3299                }
3300            });
3301
3302    broadcast(None, non_collaborators, |peer_id| {
3303        session.peer.send(
3304            peer_id,
3305            proto::UpdateChannels {
3306                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3307                    channel_id: channel_id.to_proto(),
3308                    epoch: epoch as u64,
3309                    version: version.clone(),
3310                }],
3311                ..Default::default()
3312            },
3313        )
3314    });
3315
3316    Ok(())
3317}
3318
3319/// Rejoin the channel notes after a connection blip
3320async fn rejoin_channel_buffers(
3321    request: proto::RejoinChannelBuffers,
3322    response: Response<proto::RejoinChannelBuffers>,
3323    session: Session,
3324) -> Result<()> {
3325    let db = session.db().await;
3326    let buffers = db
3327        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3328        .await?;
3329
3330    for rejoined_buffer in &buffers {
3331        let collaborators_to_notify = rejoined_buffer
3332            .buffer
3333            .collaborators
3334            .iter()
3335            .filter_map(|c| Some(c.peer_id?.into()));
3336        channel_buffer_updated(
3337            session.connection_id,
3338            collaborators_to_notify,
3339            &proto::UpdateChannelBufferCollaborators {
3340                channel_id: rejoined_buffer.buffer.channel_id,
3341                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3342            },
3343            &session.peer,
3344        );
3345    }
3346
3347    response.send(proto::RejoinChannelBuffersResponse {
3348        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3349    })?;
3350
3351    Ok(())
3352}
3353
3354/// Stop editing the channel notes
3355async fn leave_channel_buffer(
3356    request: proto::LeaveChannelBuffer,
3357    response: Response<proto::LeaveChannelBuffer>,
3358    session: Session,
3359) -> Result<()> {
3360    let db = session.db().await;
3361    let channel_id = ChannelId::from_proto(request.channel_id);
3362
3363    let left_buffer = db
3364        .leave_channel_buffer(channel_id, session.connection_id)
3365        .await?;
3366
3367    response.send(Ack {})?;
3368
3369    channel_buffer_updated(
3370        session.connection_id,
3371        left_buffer.connections,
3372        &proto::UpdateChannelBufferCollaborators {
3373            channel_id: channel_id.to_proto(),
3374            collaborators: left_buffer.collaborators,
3375        },
3376        &session.peer,
3377    );
3378
3379    Ok(())
3380}
3381
3382fn channel_buffer_updated<T: EnvelopedMessage>(
3383    sender_id: ConnectionId,
3384    collaborators: impl IntoIterator<Item = ConnectionId>,
3385    message: &T,
3386    peer: &Peer,
3387) {
3388    broadcast(Some(sender_id), collaborators, |peer_id| {
3389        peer.send(peer_id, message.clone())
3390    });
3391}
3392
3393fn send_notifications(
3394    connection_pool: &ConnectionPool,
3395    peer: &Peer,
3396    notifications: db::NotificationBatch,
3397) {
3398    for (user_id, notification) in notifications {
3399        for connection_id in connection_pool.user_connection_ids(user_id) {
3400            if let Err(error) = peer.send(
3401                connection_id,
3402                proto::AddNotification {
3403                    notification: Some(notification.clone()),
3404                },
3405            ) {
3406                tracing::error!(
3407                    "failed to send notification to {:?} {}",
3408                    connection_id,
3409                    error
3410                );
3411            }
3412        }
3413    }
3414}
3415
3416/// Send a message to the channel
3417async fn send_channel_message(
3418    request: proto::SendChannelMessage,
3419    response: Response<proto::SendChannelMessage>,
3420    session: Session,
3421) -> Result<()> {
3422    // Validate the message body.
3423    let body = request.body.trim().to_string();
3424    if body.len() > MAX_MESSAGE_LEN {
3425        return Err(anyhow!("message is too long"))?;
3426    }
3427    if body.is_empty() {
3428        return Err(anyhow!("message can't be blank"))?;
3429    }
3430
3431    // TODO: adjust mentions if body is trimmed
3432
3433    let timestamp = OffsetDateTime::now_utc();
3434    let nonce = request
3435        .nonce
3436        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3437
3438    let channel_id = ChannelId::from_proto(request.channel_id);
3439    let CreatedChannelMessage {
3440        message_id,
3441        participant_connection_ids,
3442        notifications,
3443    } = session
3444        .db()
3445        .await
3446        .create_channel_message(
3447            channel_id,
3448            session.user_id(),
3449            &body,
3450            &request.mentions,
3451            timestamp,
3452            nonce.clone().into(),
3453            request.reply_to_message_id.map(MessageId::from_proto),
3454        )
3455        .await?;
3456
3457    let message = proto::ChannelMessage {
3458        sender_id: session.user_id().to_proto(),
3459        id: message_id.to_proto(),
3460        body,
3461        mentions: request.mentions,
3462        timestamp: timestamp.unix_timestamp() as u64,
3463        nonce: Some(nonce),
3464        reply_to_message_id: request.reply_to_message_id,
3465        edited_at: None,
3466    };
3467    broadcast(
3468        Some(session.connection_id),
3469        participant_connection_ids.clone(),
3470        |connection| {
3471            session.peer.send(
3472                connection,
3473                proto::ChannelMessageSent {
3474                    channel_id: channel_id.to_proto(),
3475                    message: Some(message.clone()),
3476                },
3477            )
3478        },
3479    );
3480    response.send(proto::SendChannelMessageResponse {
3481        message: Some(message),
3482    })?;
3483
3484    let pool = &*session.connection_pool().await;
3485    let non_participants =
3486        pool.channel_connection_ids(channel_id)
3487            .filter_map(|(connection_id, _)| {
3488                if participant_connection_ids.contains(&connection_id) {
3489                    None
3490                } else {
3491                    Some(connection_id)
3492                }
3493            });
3494    broadcast(None, non_participants, |peer_id| {
3495        session.peer.send(
3496            peer_id,
3497            proto::UpdateChannels {
3498                latest_channel_message_ids: vec![proto::ChannelMessageId {
3499                    channel_id: channel_id.to_proto(),
3500                    message_id: message_id.to_proto(),
3501                }],
3502                ..Default::default()
3503            },
3504        )
3505    });
3506    send_notifications(pool, &session.peer, notifications);
3507
3508    Ok(())
3509}
3510
3511/// Delete a channel message
3512async fn remove_channel_message(
3513    request: proto::RemoveChannelMessage,
3514    response: Response<proto::RemoveChannelMessage>,
3515    session: Session,
3516) -> Result<()> {
3517    let channel_id = ChannelId::from_proto(request.channel_id);
3518    let message_id = MessageId::from_proto(request.message_id);
3519    let (connection_ids, existing_notification_ids) = session
3520        .db()
3521        .await
3522        .remove_channel_message(channel_id, message_id, session.user_id())
3523        .await?;
3524
3525    broadcast(
3526        Some(session.connection_id),
3527        connection_ids,
3528        move |connection| {
3529            session.peer.send(connection, request.clone())?;
3530
3531            for notification_id in &existing_notification_ids {
3532                session.peer.send(
3533                    connection,
3534                    proto::DeleteNotification {
3535                        notification_id: (*notification_id).to_proto(),
3536                    },
3537                )?;
3538            }
3539
3540            Ok(())
3541        },
3542    );
3543    response.send(proto::Ack {})?;
3544    Ok(())
3545}
3546
3547async fn update_channel_message(
3548    request: proto::UpdateChannelMessage,
3549    response: Response<proto::UpdateChannelMessage>,
3550    session: Session,
3551) -> Result<()> {
3552    let channel_id = ChannelId::from_proto(request.channel_id);
3553    let message_id = MessageId::from_proto(request.message_id);
3554    let updated_at = OffsetDateTime::now_utc();
3555    let UpdatedChannelMessage {
3556        message_id,
3557        participant_connection_ids,
3558        notifications,
3559        reply_to_message_id,
3560        timestamp,
3561        deleted_mention_notification_ids,
3562        updated_mention_notifications,
3563    } = session
3564        .db()
3565        .await
3566        .update_channel_message(
3567            channel_id,
3568            message_id,
3569            session.user_id(),
3570            request.body.as_str(),
3571            &request.mentions,
3572            updated_at,
3573        )
3574        .await?;
3575
3576    let nonce = request
3577        .nonce
3578        .clone()
3579        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3580
3581    let message = proto::ChannelMessage {
3582        sender_id: session.user_id().to_proto(),
3583        id: message_id.to_proto(),
3584        body: request.body.clone(),
3585        mentions: request.mentions.clone(),
3586        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3587        nonce: Some(nonce),
3588        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3589        edited_at: Some(updated_at.unix_timestamp() as u64),
3590    };
3591
3592    response.send(proto::Ack {})?;
3593
3594    let pool = &*session.connection_pool().await;
3595    broadcast(
3596        Some(session.connection_id),
3597        participant_connection_ids,
3598        |connection| {
3599            session.peer.send(
3600                connection,
3601                proto::ChannelMessageUpdate {
3602                    channel_id: channel_id.to_proto(),
3603                    message: Some(message.clone()),
3604                },
3605            )?;
3606
3607            for notification_id in &deleted_mention_notification_ids {
3608                session.peer.send(
3609                    connection,
3610                    proto::DeleteNotification {
3611                        notification_id: (*notification_id).to_proto(),
3612                    },
3613                )?;
3614            }
3615
3616            for notification in &updated_mention_notifications {
3617                session.peer.send(
3618                    connection,
3619                    proto::UpdateNotification {
3620                        notification: Some(notification.clone()),
3621                    },
3622                )?;
3623            }
3624
3625            Ok(())
3626        },
3627    );
3628
3629    send_notifications(pool, &session.peer, notifications);
3630
3631    Ok(())
3632}
3633
3634/// Mark a channel message as read
3635async fn acknowledge_channel_message(
3636    request: proto::AckChannelMessage,
3637    session: Session,
3638) -> Result<()> {
3639    let channel_id = ChannelId::from_proto(request.channel_id);
3640    let message_id = MessageId::from_proto(request.message_id);
3641    let notifications = session
3642        .db()
3643        .await
3644        .observe_channel_message(channel_id, session.user_id(), message_id)
3645        .await?;
3646    send_notifications(
3647        &*session.connection_pool().await,
3648        &session.peer,
3649        notifications,
3650    );
3651    Ok(())
3652}
3653
3654/// Mark a buffer version as synced
3655async fn acknowledge_buffer_version(
3656    request: proto::AckBufferOperation,
3657    session: Session,
3658) -> Result<()> {
3659    let buffer_id = BufferId::from_proto(request.buffer_id);
3660    session
3661        .db()
3662        .await
3663        .observe_buffer_version(
3664            buffer_id,
3665            session.user_id(),
3666            request.epoch as i32,
3667            &request.version,
3668        )
3669        .await?;
3670    Ok(())
3671}
3672
3673/// Get a Supermaven API key for the user
3674async fn get_supermaven_api_key(
3675    _request: proto::GetSupermavenApiKey,
3676    response: Response<proto::GetSupermavenApiKey>,
3677    session: Session,
3678) -> Result<()> {
3679    let user_id: String = session.user_id().to_string();
3680    if !session.is_staff() {
3681        return Err(anyhow!("supermaven not enabled for this account"))?;
3682    }
3683
3684    let email = session
3685        .email()
3686        .ok_or_else(|| anyhow!("user must have an email"))?;
3687
3688    let supermaven_admin_api = session
3689        .supermaven_client
3690        .as_ref()
3691        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3692
3693    let result = supermaven_admin_api
3694        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3695        .await?;
3696
3697    response.send(proto::GetSupermavenApiKeyResponse {
3698        api_key: result.api_key,
3699    })?;
3700
3701    Ok(())
3702}
3703
3704/// Start receiving chat updates for a channel
3705async fn join_channel_chat(
3706    request: proto::JoinChannelChat,
3707    response: Response<proto::JoinChannelChat>,
3708    session: Session,
3709) -> Result<()> {
3710    let channel_id = ChannelId::from_proto(request.channel_id);
3711
3712    let db = session.db().await;
3713    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3714        .await?;
3715    let messages = db
3716        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3717        .await?;
3718    response.send(proto::JoinChannelChatResponse {
3719        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3720        messages,
3721    })?;
3722    Ok(())
3723}
3724
3725/// Stop receiving chat updates for a channel
3726async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3727    let channel_id = ChannelId::from_proto(request.channel_id);
3728    session
3729        .db()
3730        .await
3731        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3732        .await?;
3733    Ok(())
3734}
3735
3736/// Retrieve the chat history for a channel
3737async fn get_channel_messages(
3738    request: proto::GetChannelMessages,
3739    response: Response<proto::GetChannelMessages>,
3740    session: Session,
3741) -> Result<()> {
3742    let channel_id = ChannelId::from_proto(request.channel_id);
3743    let messages = session
3744        .db()
3745        .await
3746        .get_channel_messages(
3747            channel_id,
3748            session.user_id(),
3749            MESSAGE_COUNT_PER_PAGE,
3750            Some(MessageId::from_proto(request.before_message_id)),
3751        )
3752        .await?;
3753    response.send(proto::GetChannelMessagesResponse {
3754        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3755        messages,
3756    })?;
3757    Ok(())
3758}
3759
3760/// Retrieve specific chat messages
3761async fn get_channel_messages_by_id(
3762    request: proto::GetChannelMessagesById,
3763    response: Response<proto::GetChannelMessagesById>,
3764    session: Session,
3765) -> Result<()> {
3766    let message_ids = request
3767        .message_ids
3768        .iter()
3769        .map(|id| MessageId::from_proto(*id))
3770        .collect::<Vec<_>>();
3771    let messages = session
3772        .db()
3773        .await
3774        .get_channel_messages_by_id(session.user_id(), &message_ids)
3775        .await?;
3776    response.send(proto::GetChannelMessagesResponse {
3777        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3778        messages,
3779    })?;
3780    Ok(())
3781}
3782
3783/// Retrieve the current users notifications
3784async fn get_notifications(
3785    request: proto::GetNotifications,
3786    response: Response<proto::GetNotifications>,
3787    session: Session,
3788) -> Result<()> {
3789    let notifications = session
3790        .db()
3791        .await
3792        .get_notifications(
3793            session.user_id(),
3794            NOTIFICATION_COUNT_PER_PAGE,
3795            request.before_id.map(db::NotificationId::from_proto),
3796        )
3797        .await?;
3798    response.send(proto::GetNotificationsResponse {
3799        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3800        notifications,
3801    })?;
3802    Ok(())
3803}
3804
3805/// Mark notifications as read
3806async fn mark_notification_as_read(
3807    request: proto::MarkNotificationRead,
3808    response: Response<proto::MarkNotificationRead>,
3809    session: Session,
3810) -> Result<()> {
3811    let database = &session.db().await;
3812    let notifications = database
3813        .mark_notification_as_read_by_id(
3814            session.user_id(),
3815            NotificationId::from_proto(request.notification_id),
3816        )
3817        .await?;
3818    send_notifications(
3819        &*session.connection_pool().await,
3820        &session.peer,
3821        notifications,
3822    );
3823    response.send(proto::Ack {})?;
3824    Ok(())
3825}
3826
3827/// Get the current users information
3828async fn get_private_user_info(
3829    _request: proto::GetPrivateUserInfo,
3830    response: Response<proto::GetPrivateUserInfo>,
3831    session: Session,
3832) -> Result<()> {
3833    let db = session.db().await;
3834
3835    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3836    let user = db
3837        .get_user_by_id(session.user_id())
3838        .await?
3839        .ok_or_else(|| anyhow!("user not found"))?;
3840    let flags = db.get_user_flags(session.user_id()).await?;
3841
3842    response.send(proto::GetPrivateUserInfoResponse {
3843        metrics_id,
3844        staff: user.admin,
3845        flags,
3846        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3847    })?;
3848    Ok(())
3849}
3850
3851/// Accept the terms of service (tos) on behalf of the current user
3852async fn accept_terms_of_service(
3853    _request: proto::AcceptTermsOfService,
3854    response: Response<proto::AcceptTermsOfService>,
3855    session: Session,
3856) -> Result<()> {
3857    let db = session.db().await;
3858
3859    let accepted_tos_at = Utc::now();
3860    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3861        .await?;
3862
3863    response.send(proto::AcceptTermsOfServiceResponse {
3864        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3865    })?;
3866    Ok(())
3867}
3868
/// The minimum account age an account must have in order to use the LLM service.
///
/// Expressed as a 30-day [`chrono::Duration`]. NOTE(review): `get_llm_api_token`
/// in this file does not reference this constant directly; confirm where the
/// account-age requirement is actually enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3871
3872async fn get_llm_api_token(
3873    _request: proto::GetLlmToken,
3874    response: Response<proto::GetLlmToken>,
3875    session: Session,
3876) -> Result<()> {
3877    let db = session.db().await;
3878
3879    let flags = db.get_user_flags(session.user_id()).await?;
3880    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3881
3882    if !session.is_staff() && !has_language_models_feature_flag {
3883        Err(anyhow!("permission denied"))?
3884    }
3885
3886    let user_id = session.user_id();
3887    let user = db
3888        .get_user_by_id(user_id)
3889        .await?
3890        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3891
3892    if user.accepted_tos_at.is_none() {
3893        Err(anyhow!("terms of service not accepted"))?
3894    }
3895
3896    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
3897    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
3898    let billing_preferences = db.get_billing_preferences(user.id).await?;
3899
3900    let token = LlmTokenClaims::create(
3901        &user,
3902        session.is_staff(),
3903        billing_preferences,
3904        &flags,
3905        has_legacy_llm_subscription,
3906        billing_subscription,
3907        session.system_id.clone(),
3908        &session.app_state.config,
3909    )?;
3910    response.send(proto::GetLlmTokenResponse { token })?;
3911    Ok(())
3912}
3913
3914fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3915    let message = match message {
3916        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3917        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3918        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3919        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3920        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3921            code: frame.code.into(),
3922            reason: frame.reason.as_str().to_owned().into(),
3923        })),
3924        // We should never receive a frame while reading the message, according
3925        // to the `tungstenite` maintainers:
3926        //
3927        // > It cannot occur when you read messages from the WebSocket, but it
3928        // > can be used when you want to send the raw frames (e.g. you want to
3929        // > send the frames to the WebSocket without composing the full message first).
3930        // >
3931        // > — https://github.com/snapview/tungstenite-rs/issues/268
3932        TungsteniteMessage::Frame(_) => {
3933            bail!("received an unexpected frame while reading the message")
3934        }
3935    };
3936
3937    Ok(message)
3938}
3939
3940fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3941    match message {
3942        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3943        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3944        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3945        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3946        AxumMessage::Close(frame) => {
3947            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3948                code: frame.code.into(),
3949                reason: frame.reason.as_ref().into(),
3950            }))
3951        }
3952    }
3953}
3954
3955fn notify_membership_updated(
3956    connection_pool: &mut ConnectionPool,
3957    result: MembershipUpdated,
3958    user_id: UserId,
3959    peer: &Peer,
3960) {
3961    for membership in &result.new_channels.channel_memberships {
3962        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3963    }
3964    for channel_id in &result.removed_channels {
3965        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3966    }
3967
3968    let user_channels_update = proto::UpdateUserChannels {
3969        channel_memberships: result
3970            .new_channels
3971            .channel_memberships
3972            .iter()
3973            .map(|cm| proto::ChannelMembership {
3974                channel_id: cm.channel_id.to_proto(),
3975                role: cm.role.into(),
3976            })
3977            .collect(),
3978        ..Default::default()
3979    };
3980
3981    let mut update = build_channels_update(result.new_channels);
3982    update.delete_channels = result
3983        .removed_channels
3984        .into_iter()
3985        .map(|id| id.to_proto())
3986        .collect();
3987    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3988
3989    for connection_id in connection_pool.user_connection_ids(user_id) {
3990        peer.send(connection_id, user_channels_update.clone())
3991            .trace_err();
3992        peer.send(connection_id, update.clone()).trace_err();
3993    }
3994}
3995
3996fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3997    proto::UpdateUserChannels {
3998        channel_memberships: channels
3999            .channel_memberships
4000            .iter()
4001            .map(|m| proto::ChannelMembership {
4002                channel_id: m.channel_id.to_proto(),
4003                role: m.role.into(),
4004            })
4005            .collect(),
4006        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4007        observed_channel_message_id: channels.observed_channel_messages.clone(),
4008    }
4009}
4010
4011fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4012    let mut update = proto::UpdateChannels::default();
4013
4014    for channel in channels.channels {
4015        update.channels.push(channel.to_proto());
4016    }
4017
4018    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4019    update.latest_channel_message_ids = channels.latest_channel_messages;
4020
4021    for (channel_id, participants) in channels.channel_participants {
4022        update
4023            .channel_participants
4024            .push(proto::ChannelParticipants {
4025                channel_id: channel_id.to_proto(),
4026                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4027            });
4028    }
4029
4030    for channel in channels.invited_channels {
4031        update.channel_invitations.push(channel.to_proto());
4032    }
4033
4034    update
4035}
4036
4037fn build_initial_contacts_update(
4038    contacts: Vec<db::Contact>,
4039    pool: &ConnectionPool,
4040) -> proto::UpdateContacts {
4041    let mut update = proto::UpdateContacts::default();
4042
4043    for contact in contacts {
4044        match contact {
4045            db::Contact::Accepted { user_id, busy } => {
4046                update.contacts.push(contact_for_user(user_id, busy, pool));
4047            }
4048            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4049            db::Contact::Incoming { user_id } => {
4050                update
4051                    .incoming_requests
4052                    .push(proto::IncomingContactRequest {
4053                        requester_id: user_id.to_proto(),
4054                    })
4055            }
4056        }
4057    }
4058
4059    update
4060}
4061
4062fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4063    proto::Contact {
4064        user_id: user_id.to_proto(),
4065        online: pool.is_user_online(user_id),
4066        busy,
4067    }
4068}
4069
4070fn room_updated(room: &proto::Room, peer: &Peer) {
4071    broadcast(
4072        None,
4073        room.participants
4074            .iter()
4075            .filter_map(|participant| Some(participant.peer_id?.into())),
4076        |peer_id| {
4077            peer.send(
4078                peer_id,
4079                proto::RoomUpdated {
4080                    room: Some(room.clone()),
4081                },
4082            )
4083        },
4084    );
4085}
4086
4087fn channel_updated(
4088    channel: &db::channel::Model,
4089    room: &proto::Room,
4090    peer: &Peer,
4091    pool: &ConnectionPool,
4092) {
4093    let participants = room
4094        .participants
4095        .iter()
4096        .map(|p| p.user_id)
4097        .collect::<Vec<_>>();
4098
4099    broadcast(
4100        None,
4101        pool.channel_connection_ids(channel.root_id())
4102            .filter_map(|(channel_id, role)| {
4103                role.can_see_channel(channel.visibility)
4104                    .then_some(channel_id)
4105            }),
4106        |peer_id| {
4107            peer.send(
4108                peer_id,
4109                proto::UpdateChannels {
4110                    channel_participants: vec![proto::ChannelParticipants {
4111                        channel_id: channel.id.to_proto(),
4112                        participant_user_ids: participants.clone(),
4113                    }],
4114                    ..Default::default()
4115                },
4116            )
4117        },
4118    );
4119}
4120
4121async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4122    let db = session.db().await;
4123
4124    let contacts = db.get_contacts(user_id).await?;
4125    let busy = db.is_user_busy(user_id).await?;
4126
4127    let pool = session.connection_pool().await;
4128    let updated_contact = contact_for_user(user_id, busy, &pool);
4129    for contact in contacts {
4130        if let db::Contact::Accepted {
4131            user_id: contact_user_id,
4132            ..
4133        } = contact
4134        {
4135            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4136                session
4137                    .peer
4138                    .send(
4139                        contact_conn_id,
4140                        proto::UpdateContacts {
4141                            contacts: vec![updated_contact.clone()],
4142                            remove_contacts: Default::default(),
4143                            incoming_requests: Default::default(),
4144                            remove_incoming_requests: Default::default(),
4145                            outgoing_requests: Default::default(),
4146                            remove_outgoing_requests: Default::default(),
4147                        },
4148                    )
4149                    .trace_err();
4150            }
4151        }
4152    }
4153    Ok(())
4154}
4155
/// Remove `connection_id` from the room it is in, if any, and propagate the
/// consequences.
///
/// If `leave_room` reports that the connection left a room, this:
/// - notifies the connections of each project it left (via `project_left`);
/// - broadcasts the updated room to its participants and, when the room
///   belongs to a channel, sends the new participant list to connections that
///   can see that channel;
/// - sends `CallCanceled` to every connection of users whose calls were
///   canceled;
/// - refreshes the contact status of this session's user and of those users;
/// - removes this session's user from the LiveKit room, and deletes that room
///   when `left_room.deleted` is set. LiveKit errors are logged, not returned.
///
/// Returns `Ok(())` immediately if the connection was not in a room. Errors
/// from the database or from refreshing contacts are propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once in the
    // `if let` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself so the LiveKit room name is kept for the
        // cleanup at the end; the broadcast `room` therefore carries a
        // default (emptied) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is released before the awaits on
    // `update_user_contacts` below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4229
4230async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4231    let left_channel_buffers = session
4232        .db()
4233        .await
4234        .leave_channel_buffers(session.connection_id)
4235        .await?;
4236
4237    for left_buffer in left_channel_buffers {
4238        channel_buffer_updated(
4239            session.connection_id,
4240            left_buffer.connections,
4241            &proto::UpdateChannelBufferCollaborators {
4242                channel_id: left_buffer.channel_id.to_proto(),
4243                collaborators: left_buffer.collaborators,
4244            },
4245            &session.peer,
4246        );
4247    }
4248
4249    Ok(())
4250}
4251
4252fn project_left(project: &db::LeftProject, session: &Session) {
4253    for connection_id in &project.connection_ids {
4254        if project.should_unshare {
4255            session
4256                .peer
4257                .send(
4258                    *connection_id,
4259                    proto::UnshareProject {
4260                        project_id: project.id.to_proto(),
4261                    },
4262                )
4263                .trace_err();
4264        } else {
4265            session
4266                .peer
4267                .send(
4268                    *connection_id,
4269                    proto::RemoveProjectCollaborator {
4270                        project_id: project.id.to_proto(),
4271                        peer_id: Some(session.connection_id.into()),
4272                    },
4273                )
4274                .trace_err();
4275        }
4276    }
4277}
4278
4279pub trait ResultExt {
4280    type Ok;
4281
4282    fn trace_err(self) -> Option<Self::Ok>;
4283}
4284
4285impl<T, E> ResultExt for Result<T, E>
4286where
4287    E: std::fmt::Debug,
4288{
4289    type Ok = T;
4290
4291    #[track_caller]
4292    fn trace_err(self) -> Option<T> {
4293        match self {
4294            Ok(value) => Some(value),
4295            Err(error) => {
4296                tracing::error!("{:?}", error);
4297                None
4298            }
4299        }
4300    }
4301}