rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Config, Error, RateLimit, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use http_client::HttpClient;
  37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
  38use reqwest_client::ReqwestClient;
  39use rpc::proto::split_repository_update;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{MutexGuard, Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
 128#[derive(Clone)]
 129struct Session {
 130    principal: Principal,
 131    connection_id: ConnectionId,
 132    db: Arc<tokio::sync::Mutex<DbHandle>>,
 133    peer: Arc<Peer>,
 134    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 135    app_state: Arc<AppState>,
 136    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 137    http_client: Arc<dyn HttpClient>,
 138    /// The GeoIP country code for the user.
 139    #[allow(unused)]
 140    geoip_country_code: Option<String>,
 141    system_id: Option<String>,
 142    _executor: Executor,
 143}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
 237pub struct Server {
 238    id: parking_lot::Mutex<ServerId>,
 239    peer: Arc<Peer>,
 240    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 241    app_state: Arc<AppState>,
 242    handlers: HashMap<TypeId, MessageHandler>,
 243    teardown: watch::Sender<bool>,
 244}
 245
 246pub(crate) struct ConnectionPoolGuard<'a> {
 247    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 248    _not_send: PhantomData<Rc<()>>,
 249}
 250
 251#[derive(Serialize)]
 252pub struct ServerSnapshot<'a> {
 253    peer: &'a Peer,
 254    #[serde(serialize_with = "serialize_deref")]
 255    connection_pool: ConnectionPoolGuard<'a>,
 256}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
 268    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 269        let mut server = Self {
 270            id: parking_lot::Mutex::new(id),
 271            peer: Peer::new(id.0 as u32),
 272            app_state: app_state.clone(),
 273            connection_pool: Default::default(),
 274            handlers: Default::default(),
 275            teardown: watch::channel(false).0,
 276        };
 277
 278        server
 279            .add_request_handler(ping)
 280            .add_request_handler(create_room)
 281            .add_request_handler(join_room)
 282            .add_request_handler(rejoin_room)
 283            .add_request_handler(leave_room)
 284            .add_request_handler(set_room_participant_role)
 285            .add_request_handler(call)
 286            .add_request_handler(cancel_call)
 287            .add_message_handler(decline_call)
 288            .add_request_handler(update_participant_location)
 289            .add_request_handler(share_project)
 290            .add_message_handler(unshare_project)
 291            .add_request_handler(join_project)
 292            .add_message_handler(leave_project)
 293            .add_request_handler(update_project)
 294            .add_request_handler(update_worktree)
 295            .add_request_handler(update_repository)
 296            .add_request_handler(remove_repository)
 297            .add_message_handler(start_language_server)
 298            .add_message_handler(update_language_server)
 299            .add_message_handler(update_diagnostic_summary)
 300            .add_message_handler(update_worktree_settings)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 305            .add_request_handler(forward_find_search_candidates_request)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 309            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 311            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 312            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 313            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 314            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 320            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(
 325                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 326            )
 327            .add_request_handler(
 328                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 329            )
 330            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 331            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 332            .add_request_handler(
 333                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 334            )
 335            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 336            .add_request_handler(
 337                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 338            )
 339            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 358            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 359            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 360            .add_message_handler(create_buffer_for_peer)
 361            .add_request_handler(update_buffer)
 362            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 363            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 368            .add_request_handler(get_users)
 369            .add_request_handler(fuzzy_search_users)
 370            .add_request_handler(request_contact)
 371            .add_request_handler(remove_contact)
 372            .add_request_handler(respond_to_contact_request)
 373            .add_message_handler(subscribe_to_channels)
 374            .add_request_handler(create_channel)
 375            .add_request_handler(delete_channel)
 376            .add_request_handler(invite_channel_member)
 377            .add_request_handler(remove_channel_member)
 378            .add_request_handler(set_channel_member_role)
 379            .add_request_handler(set_channel_visibility)
 380            .add_request_handler(rename_channel)
 381            .add_request_handler(join_channel_buffer)
 382            .add_request_handler(leave_channel_buffer)
 383            .add_message_handler(update_channel_buffer)
 384            .add_request_handler(rejoin_channel_buffers)
 385            .add_request_handler(get_channel_members)
 386            .add_request_handler(respond_to_channel_invite)
 387            .add_request_handler(join_channel)
 388            .add_request_handler(join_channel_chat)
 389            .add_message_handler(leave_channel_chat)
 390            .add_request_handler(send_channel_message)
 391            .add_request_handler(remove_channel_message)
 392            .add_request_handler(update_channel_message)
 393            .add_request_handler(get_channel_messages)
 394            .add_request_handler(get_channel_messages_by_id)
 395            .add_request_handler(get_notifications)
 396            .add_request_handler(mark_notification_as_read)
 397            .add_request_handler(move_channel)
 398            .add_request_handler(follow)
 399            .add_message_handler(unfollow)
 400            .add_message_handler(update_followers)
 401            .add_request_handler(get_private_user_info)
 402            .add_request_handler(get_llm_api_token)
 403            .add_request_handler(accept_terms_of_service)
 404            .add_message_handler(acknowledge_channel_message)
 405            .add_message_handler(acknowledge_buffer_version)
 406            .add_request_handler(get_supermaven_api_key)
 407            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 409            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 410            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 411            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 414            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 416            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 417            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 418            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 419            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 420            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 421            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 422            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 423            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 424            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 425            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 426            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 427            .add_message_handler(update_context)
 428            .add_request_handler({
 429                let app_state = app_state.clone();
 430                move |request, response, session| {
 431                    let app_state = app_state.clone();
 432                    async move {
 433                        count_language_model_tokens(request, response, session, &app_state.config)
 434                            .await
 435                    }
 436                }
 437            })
 438            .add_request_handler(get_cached_embeddings)
 439            .add_request_handler({
 440                let app_state = app_state.clone();
 441                move |request, response, session| {
 442                    compute_embeddings(
 443                        request,
 444                        response,
 445                        session,
 446                        app_state.config.openai_api_key.clone(),
 447                    )
 448                }
 449            });
 450
 451        Arc::new(server)
 452    }
 453
    /// Schedules best-effort cleanup of resources left behind by stale servers, then returns
    /// immediately.
    ///
    /// A detached task, instrumented with the `start server` span, does the following:
    /// 1. Waits for [`CLEANUP_TIMEOUT`].
    /// 2. Asks the database for the room and channel ids returned by
    ///    `stale_server_resource_ids` for this `zed_environment` and server id.
    ///    NOTE(review): the exact selection criteria live in that query; confirm them.
    /// 3. For each channel, clears stale channel-buffer collaborators and pushes the
    ///    refreshed collaborator list to the remaining connections.
    /// 4. For each room, clears stale participants, broadcasts the updated room (and its
    ///    channel, if any), sends `CallCanceled` for canceled calls, and pushes refreshed
    ///    contact entries for affected users. It deletes the LiveKit room once no
    ///    participants remain.
    /// 5. Calls `delete_stale_servers`.
    ///
    /// Every database or send failure is logged via `trace_err` and skipped. Always returns
    /// `Ok(())`; the cleanup itself runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminated pods time to go away before touching their resources.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every connection still in it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply if the room can't be refreshed: notify nobody and
                        // leave the LiveKit room alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both departed participants and users whose calls were canceled
                            // need their contact entry refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock covers only these synchronous sends.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Rebuild each affected user's contact entry (including busy status)
                        // and push it to all connections of their accepted contacts. The DB
                        // lookups happen before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the LiveKit room if the refresh left it empty.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 599
 600    pub fn teardown(&self) {
 601        self.peer.teardown();
 602        self.connection_pool.lock().reset();
 603        let _ = self.teardown.send(true);
 604    }
 605
 606    #[cfg(test)]
 607    pub fn reset(&self, id: ServerId) {
 608        self.teardown();
 609        *self.id.lock() = id;
 610        self.peer.reset(id.0 as u32);
 611        let _ = self.teardown.send(false);
 612    }
 613
 614    #[cfg(test)]
 615    pub fn id(&self) -> ServerId {
 616        *self.id.lock()
 617    }
 618
 619    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 620    where
 621        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 622        Fut: 'static + Send + Future<Output = Result<()>>,
 623        M: EnvelopedMessage,
 624    {
 625        let prev_handler = self.handlers.insert(
 626            TypeId::of::<M>(),
 627            Box::new(move |envelope, session| {
 628                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 629                let received_at = envelope.received_at;
 630                tracing::info!("message received");
 631                let start_time = Instant::now();
 632                let future = (handler)(*envelope, session);
 633                async move {
 634                    let result = future.await;
 635                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 636                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 637                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 638                    let payload_type = M::NAME;
 639
 640                    match result {
 641                        Err(error) => {
 642                            tracing::error!(
 643                                ?error,
 644                                total_duration_ms,
 645                                processing_duration_ms,
 646                                queue_duration_ms,
 647                                payload_type,
 648                                "error handling message"
 649                            )
 650                        }
 651                        Ok(()) => tracing::info!(
 652                            total_duration_ms,
 653                            processing_duration_ms,
 654                            queue_duration_ms,
 655                            "finished handling message"
 656                        ),
 657                    }
 658                }
 659                .boxed()
 660            }),
 661        );
 662        if prev_handler.is_some() {
 663            panic!("registered a handler for the same message twice");
 664        }
 665        self
 666    }
 667
 668    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 669    where
 670        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 671        Fut: 'static + Send + Future<Output = Result<()>>,
 672        M: EnvelopedMessage,
 673    {
 674        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 675        self
 676    }
 677
 678    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 679    where
 680        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 681        Fut: Send + Future<Output = Result<()>>,
 682        M: RequestMessage,
 683    {
 684        let handler = Arc::new(handler);
 685        self.add_handler(move |envelope, session| {
 686            let receipt = envelope.receipt();
 687            let handler = handler.clone();
 688            async move {
 689                let peer = session.peer.clone();
 690                let responded = Arc::new(AtomicBool::default());
 691                let response = Response {
 692                    peer: peer.clone(),
 693                    responded: responded.clone(),
 694                    receipt,
 695                };
 696                match (handler)(envelope.payload, response, session).await {
 697                    Ok(()) => {
 698                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 699                            Ok(())
 700                        } else {
 701                            Err(anyhow!("handler did not send a response"))?
 702                        }
 703                    }
 704                    Err(error) => {
 705                        let proto_err = match &error {
 706                            Error::Internal(err) => err.to_proto(),
 707                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 708                        };
 709                        peer.respond_with_error(receipt, proto_err)?;
 710                        Err(error)
 711                    }
 712                }
 713            }
 714        })
 715    }
 716
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. exits immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and records the connection id on
    ///    the `handle connection` span;
    /// 3. builds the per-connection `Session` (an HTTP client, plus a Supermaven
    ///    admin client when `supermaven_admin_api_key` is configured);
    /// 4. runs `send_initial_client_update`, which also reports the connection id
    ///    through `send_connection_id` when one is given;
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled;
    /// 6. on close / I/O failure (but not on teardown) runs `connection_lost`.
    ///
    /// `address`, the principal's fields (via `update_span`) and
    /// `geoip_country_code` are recorded on the tracing span.
    ///
    /// NOTE(review): the early `return`s after `add_connection` (HTTP client
    /// creation failure, failed initial client update) skip `connection_lost`,
    /// so nothing visible here disconnects the peer or removes the pool entry
    /// added by `send_initial_client_update` — confirm this is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Teardown is checked once up front and then watched in the message loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure lets the peer create sleep futures on this executor
            // (presumably for its keepalive/timeout timers — confirm in `Peer`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before each message is read and released when its handler
            // finishes (foreground or background), capping in-flight handlers at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: teardown first,
                // then I/O, then progress on running handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is held until the
                                // handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 865
 866    async fn send_initial_client_update(
 867        &self,
 868        connection_id: ConnectionId,
 869        principal: &Principal,
 870        zed_version: ZedVersion,
 871        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 872        session: &Session,
 873    ) -> Result<()> {
 874        self.peer.send(
 875            connection_id,
 876            proto::Hello {
 877                peer_id: Some(connection_id.into()),
 878            },
 879        )?;
 880        tracing::info!("sent hello message");
 881        if let Some(send_connection_id) = send_connection_id.take() {
 882            let _ = send_connection_id.send(connection_id);
 883        }
 884
 885        match principal {
 886            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 887                if !user.connected_once {
 888                    self.peer.send(connection_id, proto::ShowContacts {})?;
 889                    self.app_state
 890                        .db
 891                        .set_user_connected_once(user.id, true)
 892                        .await?;
 893                }
 894
 895                update_user_plan(user.id, session).await?;
 896
 897                let contacts = self.app_state.db.get_contacts(user.id).await?;
 898
 899                {
 900                    let mut pool = self.connection_pool.lock();
 901                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 902                    self.peer.send(
 903                        connection_id,
 904                        build_initial_contacts_update(contacts, &pool),
 905                    )?;
 906                }
 907
 908                if should_auto_subscribe_to_channels(zed_version) {
 909                    subscribe_user_to_channels(user.id, session).await?;
 910                }
 911
 912                if let Some(incoming_call) =
 913                    self.app_state.db.incoming_call_for_user(user.id).await?
 914                {
 915                    self.peer.send(connection_id, incoming_call)?;
 916                }
 917
 918                update_user_contacts(user.id, session).await?;
 919            }
 920        }
 921
 922        Ok(())
 923    }
 924
 925    pub async fn invite_code_redeemed(
 926        self: &Arc<Self>,
 927        inviter_id: UserId,
 928        invitee_id: UserId,
 929    ) -> Result<()> {
 930        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 931            if let Some(code) = &user.invite_code {
 932                let pool = self.connection_pool.lock();
 933                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 934                for connection_id in pool.user_connection_ids(inviter_id) {
 935                    self.peer.send(
 936                        connection_id,
 937                        proto::UpdateContacts {
 938                            contacts: vec![invitee_contact.clone()],
 939                            ..Default::default()
 940                        },
 941                    )?;
 942                    self.peer.send(
 943                        connection_id,
 944                        proto::UpdateInviteInfo {
 945                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 946                            count: user.invite_count as u32,
 947                        },
 948                    )?;
 949                }
 950            }
 951        }
 952        Ok(())
 953    }
 954
 955    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 956        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 957            if let Some(invite_code) = &user.invite_code {
 958                let pool = self.connection_pool.lock();
 959                for connection_id in pool.user_connection_ids(user_id) {
 960                    self.peer.send(
 961                        connection_id,
 962                        proto::UpdateInviteInfo {
 963                            url: format!(
 964                                "{}{}",
 965                                self.app_state.config.invite_link_prefix, invite_code
 966                            ),
 967                            count: user.invite_count as u32,
 968                        },
 969                    )?;
 970                }
 971            }
 972        }
 973        Ok(())
 974    }
 975
 976    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 977        let pool = self.connection_pool.lock();
 978        for connection_id in pool.user_connection_ids(user_id) {
 979            self.peer
 980                .send(connection_id, proto::RefreshLlmToken {})
 981                .trace_err();
 982        }
 983    }
 984
 985    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 986        ServerSnapshot {
 987            connection_pool: ConnectionPoolGuard {
 988                guard: self.connection_pool.lock(),
 989                _not_send: PhantomData,
 990            },
 991            peer: &self.peer,
 992        }
 993    }
 994}
 995
 996impl Deref for ConnectionPoolGuard<'_> {
 997    type Target = ConnectionPool;
 998
 999    fn deref(&self) -> &Self::Target {
1000        &self.guard
1001    }
1002}
1003
1004impl DerefMut for ConnectionPoolGuard<'_> {
1005    fn deref_mut(&mut self) -> &mut Self::Target {
1006        &mut self.guard
1007    }
1008}
1009
1010impl Drop for ConnectionPoolGuard<'_> {
1011    fn drop(&mut self) {
1012        #[cfg(test)]
1013        self.check_invariants();
1014    }
1015}
1016
1017fn broadcast<F>(
1018    sender_id: Option<ConnectionId>,
1019    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1020    mut f: F,
1021) where
1022    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1023{
1024    for receiver_id in receiver_ids {
1025        if Some(receiver_id) != sender_id {
1026            if let Err(error) = f(receiver_id) {
1027                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1028            }
1029        }
1030    }
1031}
1032
1033pub struct ProtocolVersion(u32);
1034
1035impl Header for ProtocolVersion {
1036    fn name() -> &'static HeaderName {
1037        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1039    }
1040
1041    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042    where
1043        Self: Sized,
1044        I: Iterator<Item = &'i axum::http::HeaderValue>,
1045    {
1046        let version = values
1047            .next()
1048            .ok_or_else(axum::headers::Error::invalid)?
1049            .to_str()
1050            .map_err(|_| axum::headers::Error::invalid())?
1051            .parse()
1052            .map_err(|_| axum::headers::Error::invalid())?;
1053        Ok(Self(version))
1054    }
1055
1056    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057        values.extend([self.0.to_string().parse().unwrap()]);
1058    }
1059}
1060
1061pub struct AppVersionHeader(SemanticVersion);
1062impl Header for AppVersionHeader {
1063    fn name() -> &'static HeaderName {
1064        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1065        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1066    }
1067
1068    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069    where
1070        Self: Sized,
1071        I: Iterator<Item = &'i axum::http::HeaderValue>,
1072    {
1073        let version = values
1074            .next()
1075            .ok_or_else(axum::headers::Error::invalid)?
1076            .to_str()
1077            .map_err(|_| axum::headers::Error::invalid())?
1078            .parse()
1079            .map_err(|_| axum::headers::Error::invalid())?;
1080        Ok(Self(version))
1081    }
1082
1083    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084        values.extend([self.0.to_string().parse().unwrap()]);
1085    }
1086}
1087
/// Builds the collab RPC router.
///
/// - `GET /rpc`: the WebSocket endpoint. It is wrapped in a layer that provides
///   the `AppState` extension and runs `auth::validate_header` before the handler.
/// - `GET /metrics`: Prometheus metrics. It is registered *after* that layer, and
///   axum's `Router::layer` only wraps routes added before the call, so this
///   route is not behind `validate_header`.
///
/// The `Arc<Server>` extension is layered last and so is available to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Wraps only the routes registered so far (`/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Wraps every route above, including `/metrics`.
        .layer(Extension(server))
}
1099
1100pub async fn handle_websocket_request(
1101    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1102    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1103    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1104    Extension(server): Extension<Arc<Server>>,
1105    Extension(principal): Extension<Principal>,
1106    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1107    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1108    ws: WebSocketUpgrade,
1109) -> axum::response::Response {
1110    if protocol_version != rpc::PROTOCOL_VERSION {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "client must be upgraded".to_string(),
1114        )
1115            .into_response();
1116    }
1117
1118    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1119        return (
1120            StatusCode::UPGRADE_REQUIRED,
1121            "no version header found".to_string(),
1122        )
1123            .into_response();
1124    };
1125
1126    if !version.can_collaborate() {
1127        return (
1128            StatusCode::UPGRADE_REQUIRED,
1129            "client must be upgraded".to_string(),
1130        )
1131            .into_response();
1132    }
1133
1134    let socket_address = socket_address.to_string();
1135    ws.on_upgrade(move |socket| {
1136        let socket = socket
1137            .map_ok(to_tungstenite_message)
1138            .err_into()
1139            .with(|message| async move { to_axum_message(message) });
1140        let connection = Connection::new(Box::pin(socket));
1141        async move {
1142            server
1143                .handle_connection(
1144                    connection,
1145                    socket_address,
1146                    principal,
1147                    version,
1148                    country_code_header.map(|header| header.to_string()),
1149                    system_id_header.map(|header| header.to_string()),
1150                    None,
1151                    Executor::Production,
1152                )
1153                .await;
1154        }
1155    })
1156}
1157
1158pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1159    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1160    let connections_metric = CONNECTIONS_METRIC
1161        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1162
1163    let connections = server
1164        .connection_pool
1165        .lock()
1166        .connections()
1167        .filter(|connection| !connection.admin)
1168        .count();
1169    connections_metric.set(connections as _);
1170
1171    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1172    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1173        register_int_gauge!(
1174            "shared_projects",
1175            "number of open projects with one or more guests"
1176        )
1177        .unwrap()
1178    });
1179
1180    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1181    shared_projects_metric.set(shared_projects as _);
1182
1183    let encoder = prometheus::TextEncoder::new();
1184    let metric_families = prometheus::gather();
1185    let encoded_metrics = encoder
1186        .encode_to_string(&metric_families)
1187        .map_err(|err| anyhow!("{}", err))?;
1188    Ok(encoded_metrics)
1189}
1190
/// Cleans up after a client connection has gone away.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (a failure here aborts with an error);
/// - reports the lost connection to the database (errors there are only logged).
///
/// It then waits `RECONNECT_TIMEOUT`, presumably to give the client a window to
/// reconnect — confirm against the constant's definition. If `teardown` changes
/// first, the wait is abandoned and nothing further happens. Otherwise it:
/// - leaves the room and channel buffers for this connection;
/// - calls `decline_call` if the user has no connection left online, and
///   broadcasts any resulting room update;
/// - updates the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch wins over teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // Resource cleanup steps below are best-effort (logged via `trace_err`).
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1237
1238/// Acknowledges a ping from a client, used to keep the connection alive.
1239async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1240    response.send(proto::Ack {})?;
1241    Ok(())
1242}
1243
1244/// Creates a new room for calling (outside of channels)
1245async fn create_room(
1246    _request: proto::CreateRoom,
1247    response: Response<proto::CreateRoom>,
1248    session: Session,
1249) -> Result<()> {
1250    let livekit_room = nanoid::nanoid!(30);
1251
1252    let live_kit_connection_info = util::maybe!(async {
1253        let live_kit = session.app_state.livekit_client.as_ref();
1254        let live_kit = live_kit?;
1255        let user_id = session.user_id().to_string();
1256
1257        let token = live_kit
1258            .room_token(&livekit_room, &user_id.to_string())
1259            .trace_err()?;
1260
1261        Some(proto::LiveKitConnectionInfo {
1262            server_url: live_kit.url().into(),
1263            token,
1264            can_publish: true,
1265        })
1266    })
1267    .await;
1268
1269    let room = session
1270        .db()
1271        .await
1272        .create_room(session.user_id(), session.connection_id, &livekit_room)
1273        .await?;
1274
1275    response.send(proto::CreateRoomResponse {
1276        room: Some(room.clone()),
1277        live_kit_connection_info,
1278    })?;
1279
1280    update_user_contacts(session.user_id(), &session).await?;
1281    Ok(())
1282}
1283
1284/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1285async fn join_room(
1286    request: proto::JoinRoom,
1287    response: Response<proto::JoinRoom>,
1288    session: Session,
1289) -> Result<()> {
1290    let room_id = RoomId::from_proto(request.id);
1291
1292    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1293
1294    if let Some(channel_id) = channel_id {
1295        return join_channel_internal(channel_id, Box::new(response), session).await;
1296    }
1297
1298    let joined_room = {
1299        let room = session
1300            .db()
1301            .await
1302            .join_room(room_id, session.user_id(), session.connection_id)
1303            .await?;
1304        room_updated(&room.room, &session.peer);
1305        room.into_inner()
1306    };
1307
1308    for connection_id in session
1309        .connection_pool()
1310        .await
1311        .user_connection_ids(session.user_id())
1312    {
1313        session
1314            .peer
1315            .send(
1316                connection_id,
1317                proto::CallCanceled {
1318                    room_id: room_id.to_proto(),
1319                },
1320            )
1321            .trace_err();
1322    }
1323
1324    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1325    {
1326        live_kit
1327            .room_token(
1328                &joined_room.room.livekit_room,
1329                &session.user_id().to_string(),
1330            )
1331            .trace_err()
1332            .map(|token| proto::LiveKitConnectionInfo {
1333                server_url: live_kit.url().into(),
1334                token,
1335                can_publish: true,
1336            })
1337    } else {
1338        None
1339    };
1340
1341    response.send(proto::JoinRoomResponse {
1342        room: Some(joined_room.room),
1343        channel_id: None,
1344        live_kit_connection_info,
1345    })?;
1346
1347    update_user_contacts(session.user_id(), &session).await?;
1348    Ok(())
1349}
1350
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The handler:
/// 1. replies with the room state plus the projects this user re-shared and
///    re-joined, then broadcasts the room update;
/// 2. for each re-shared project, tells its collaborators about the caller's new
///    peer id and forwards the project's worktrees to them;
/// 3. streams rejoined-project state back to the caller via
///    `notify_rejoined_projects`;
/// 4. refreshes the room's channel (if any) and the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below, after `rejoined_room` is consumed via `into_inner`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators swap the caller's old peer id for the current connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Send the re-shared worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1442
1443fn notify_rejoined_projects(
1444    rejoined_projects: &mut Vec<RejoinedProject>,
1445    session: &Session,
1446) -> Result<()> {
1447    for project in rejoined_projects.iter() {
1448        for collaborator in &project.collaborators {
1449            session
1450                .peer
1451                .send(
1452                    collaborator.connection_id,
1453                    proto::UpdateProjectCollaborator {
1454                        project_id: project.id.to_proto(),
1455                        old_peer_id: Some(project.old_connection_id.into()),
1456                        new_peer_id: Some(session.connection_id.into()),
1457                    },
1458                )
1459                .trace_err();
1460        }
1461    }
1462
1463    for project in rejoined_projects {
1464        for worktree in mem::take(&mut project.worktrees) {
1465            // Stream this worktree's entries.
1466            let message = proto::UpdateWorktree {
1467                project_id: project.id.to_proto(),
1468                worktree_id: worktree.id,
1469                abs_path: worktree.abs_path.clone(),
1470                root_name: worktree.root_name,
1471                updated_entries: worktree.updated_entries,
1472                removed_entries: worktree.removed_entries,
1473                scan_id: worktree.scan_id,
1474                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1475                updated_repositories: worktree.updated_repositories,
1476                removed_repositories: worktree.removed_repositories,
1477            };
1478            for update in proto::split_worktree_update(message) {
1479                session.peer.send(session.connection_id, update)?;
1480            }
1481
1482            // Stream this worktree's diagnostics.
1483            for summary in worktree.diagnostic_summaries {
1484                session.peer.send(
1485                    session.connection_id,
1486                    proto::UpdateDiagnosticSummary {
1487                        project_id: project.id.to_proto(),
1488                        worktree_id: worktree.id,
1489                        summary: Some(summary),
1490                    },
1491                )?;
1492            }
1493
1494            for settings_file in worktree.settings_files {
1495                session.peer.send(
1496                    session.connection_id,
1497                    proto::UpdateWorktreeSettings {
1498                        project_id: project.id.to_proto(),
1499                        worktree_id: worktree.id,
1500                        path: settings_file.path,
1501                        content: Some(settings_file.content),
1502                        kind: Some(settings_file.kind.to_proto().into()),
1503                    },
1504                )?;
1505            }
1506        }
1507
1508        for repository in mem::take(&mut project.updated_repositories) {
1509            for update in split_repository_update(repository) {
1510                session.peer.send(session.connection_id, update)?;
1511            }
1512        }
1513
1514        for id in mem::take(&mut project.removed_repositories) {
1515            session.peer.send(
1516                session.connection_id,
1517                proto::RemoveRepository {
1518                    project_id: project.id.to_proto(),
1519                    id,
1520                },
1521            )?;
1522        }
1523    }
1524
1525    Ok(())
1526}
1527
1528/// leave room disconnects from the room.
1529async fn leave_room(
1530    _: proto::LeaveRoom,
1531    response: Response<proto::LeaveRoom>,
1532    session: Session,
1533) -> Result<()> {
1534    leave_room_for_session(&session, session.connection_id).await?;
1535    response.send(proto::Ack {})?;
1536    Ok(())
1537}
1538
1539/// Updates the permissions of someone else in the room.
1540async fn set_room_participant_role(
1541    request: proto::SetRoomParticipantRole,
1542    response: Response<proto::SetRoomParticipantRole>,
1543    session: Session,
1544) -> Result<()> {
1545    let user_id = UserId::from_proto(request.user_id);
1546    let role = ChannelRole::from(request.role());
1547
1548    let (livekit_room, can_publish) = {
1549        let room = session
1550            .db()
1551            .await
1552            .set_room_participant_role(
1553                session.user_id(),
1554                RoomId::from_proto(request.room_id),
1555                user_id,
1556                role,
1557            )
1558            .await?;
1559
1560        let livekit_room = room.livekit_room.clone();
1561        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1562        room_updated(&room, &session.peer);
1563        (livekit_room, can_publish)
1564    };
1565
1566    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1567        live_kit
1568            .update_participant(
1569                livekit_room.clone(),
1570                request.user_id.to_string(),
1571                livekit_api::proto::ParticipantPermission {
1572                    can_subscribe: true,
1573                    can_publish,
1574                    can_publish_data: can_publish,
1575                    hidden: false,
1576                    recorder: false,
1577                },
1578            )
1579            .await
1580            .trace_err();
1581    }
1582
1583    response.send(proto::Ack {})?;
1584    Ok(())
1585}
1586
1587/// Call someone else into the current room
1588async fn call(
1589    request: proto::Call,
1590    response: Response<proto::Call>,
1591    session: Session,
1592) -> Result<()> {
1593    let room_id = RoomId::from_proto(request.room_id);
1594    let calling_user_id = session.user_id();
1595    let calling_connection_id = session.connection_id;
1596    let called_user_id = UserId::from_proto(request.called_user_id);
1597    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1598    if !session
1599        .db()
1600        .await
1601        .has_contact(calling_user_id, called_user_id)
1602        .await?
1603    {
1604        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1605    }
1606
1607    let incoming_call = {
1608        let (room, incoming_call) = &mut *session
1609            .db()
1610            .await
1611            .call(
1612                room_id,
1613                calling_user_id,
1614                calling_connection_id,
1615                called_user_id,
1616                initial_project_id,
1617            )
1618            .await?;
1619        room_updated(room, &session.peer);
1620        mem::take(incoming_call)
1621    };
1622    update_user_contacts(called_user_id, &session).await?;
1623
1624    let mut calls = session
1625        .connection_pool()
1626        .await
1627        .user_connection_ids(called_user_id)
1628        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1629        .collect::<FuturesUnordered<_>>();
1630
1631    while let Some(call_response) = calls.next().await {
1632        match call_response.as_ref() {
1633            Ok(_) => {
1634                response.send(proto::Ack {})?;
1635                return Ok(());
1636            }
1637            Err(_) => {
1638                call_response.trace_err();
1639            }
1640        }
1641    }
1642
1643    {
1644        let room = session
1645            .db()
1646            .await
1647            .call_failed(room_id, called_user_id)
1648            .await?;
1649        room_updated(&room, &session.peer);
1650    }
1651    update_user_contacts(called_user_id, &session).await?;
1652
1653    Err(anyhow!("failed to ring user"))?
1654}
1655
1656/// Cancel an outgoing call.
1657async fn cancel_call(
1658    request: proto::CancelCall,
1659    response: Response<proto::CancelCall>,
1660    session: Session,
1661) -> Result<()> {
1662    let called_user_id = UserId::from_proto(request.called_user_id);
1663    let room_id = RoomId::from_proto(request.room_id);
1664    {
1665        let room = session
1666            .db()
1667            .await
1668            .cancel_call(room_id, session.connection_id, called_user_id)
1669            .await?;
1670        room_updated(&room, &session.peer);
1671    }
1672
1673    for connection_id in session
1674        .connection_pool()
1675        .await
1676        .user_connection_ids(called_user_id)
1677    {
1678        session
1679            .peer
1680            .send(
1681                connection_id,
1682                proto::CallCanceled {
1683                    room_id: room_id.to_proto(),
1684                },
1685            )
1686            .trace_err();
1687    }
1688    response.send(proto::Ack {})?;
1689
1690    update_user_contacts(called_user_id, &session).await?;
1691    Ok(())
1692}
1693
1694/// Decline an incoming call.
1695async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1696    let room_id = RoomId::from_proto(message.room_id);
1697    {
1698        let room = session
1699            .db()
1700            .await
1701            .decline_call(Some(room_id), session.user_id())
1702            .await?
1703            .ok_or_else(|| anyhow!("failed to decline call"))?;
1704        room_updated(&room, &session.peer);
1705    }
1706
1707    for connection_id in session
1708        .connection_pool()
1709        .await
1710        .user_connection_ids(session.user_id())
1711    {
1712        session
1713            .peer
1714            .send(
1715                connection_id,
1716                proto::CallCanceled {
1717                    room_id: room_id.to_proto(),
1718                },
1719            )
1720            .trace_err();
1721    }
1722    update_user_contacts(session.user_id(), &session).await?;
1723    Ok(())
1724}
1725
1726/// Updates other participants in the room with your current location.
1727async fn update_participant_location(
1728    request: proto::UpdateParticipantLocation,
1729    response: Response<proto::UpdateParticipantLocation>,
1730    session: Session,
1731) -> Result<()> {
1732    let room_id = RoomId::from_proto(request.room_id);
1733    let location = request
1734        .location
1735        .ok_or_else(|| anyhow!("invalid location"))?;
1736
1737    let db = session.db().await;
1738    let room = db
1739        .update_room_participant_location(room_id, session.connection_id, location)
1740        .await?;
1741
1742    room_updated(&room, &session.peer);
1743    response.send(proto::Ack {})?;
1744    Ok(())
1745}
1746
1747/// Share a project into the room.
1748async fn share_project(
1749    request: proto::ShareProject,
1750    response: Response<proto::ShareProject>,
1751    session: Session,
1752) -> Result<()> {
1753    let (project_id, room) = &*session
1754        .db()
1755        .await
1756        .share_project(
1757            RoomId::from_proto(request.room_id),
1758            session.connection_id,
1759            &request.worktrees,
1760            request.is_ssh_project,
1761        )
1762        .await?;
1763    response.send(proto::ShareProjectResponse {
1764        project_id: project_id.to_proto(),
1765    })?;
1766    room_updated(room, &session.peer);
1767
1768    Ok(())
1769}
1770
1771/// Unshare a project from the room.
1772async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1773    let project_id = ProjectId::from_proto(message.project_id);
1774    unshare_project_internal(project_id, session.connection_id, &session).await
1775}
1776
1777async fn unshare_project_internal(
1778    project_id: ProjectId,
1779    connection_id: ConnectionId,
1780    session: &Session,
1781) -> Result<()> {
1782    let delete = {
1783        let room_guard = session
1784            .db()
1785            .await
1786            .unshare_project(project_id, connection_id)
1787            .await?;
1788
1789        let (delete, room, guest_connection_ids) = &*room_guard;
1790
1791        let message = proto::UnshareProject {
1792            project_id: project_id.to_proto(),
1793        };
1794
1795        broadcast(
1796            Some(connection_id),
1797            guest_connection_ids.iter().copied(),
1798            |conn_id| session.peer.send(conn_id, message.clone()),
1799        );
1800        if let Some(room) = room {
1801            room_updated(room, &session.peer);
1802        }
1803
1804        *delete
1805    };
1806
1807    if delete {
1808        let db = session.db().await;
1809        db.delete_project(project_id).await?;
1810    }
1811
1812    Ok(())
1813}
1814
1815/// Join someone elses shared project.
1816async fn join_project(
1817    request: proto::JoinProject,
1818    response: Response<proto::JoinProject>,
1819    session: Session,
1820) -> Result<()> {
1821    let project_id = ProjectId::from_proto(request.project_id);
1822
1823    tracing::info!(%project_id, "join project");
1824
1825    let db = session.db().await;
1826    let (project, replica_id) = &mut *db
1827        .join_project(project_id, session.connection_id, session.user_id())
1828        .await?;
1829    drop(db);
1830    tracing::info!(%project_id, "join remote project");
1831    join_project_internal(response, session, project, replica_id)
1832}
1833
/// Destination for the `JoinProjectResponse` produced by `join_project_internal`.
///
/// `join_project_internal` is generic over this trait. The only implementation visible in
/// this part of the file is for `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// The regular RPC path: forwards to `Response::<proto::JoinProject>::send`, called by its
/// explicit path.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1842
1843fn join_project_internal(
1844    response: impl JoinProjectInternalResponse,
1845    session: Session,
1846    project: &mut Project,
1847    replica_id: &ReplicaId,
1848) -> Result<()> {
1849    let collaborators = project
1850        .collaborators
1851        .iter()
1852        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1853        .map(|collaborator| collaborator.to_proto())
1854        .collect::<Vec<_>>();
1855    let project_id = project.id;
1856    let guest_user_id = session.user_id();
1857
1858    let worktrees = project
1859        .worktrees
1860        .iter()
1861        .map(|(id, worktree)| proto::WorktreeMetadata {
1862            id: *id,
1863            root_name: worktree.root_name.clone(),
1864            visible: worktree.visible,
1865            abs_path: worktree.abs_path.clone(),
1866        })
1867        .collect::<Vec<_>>();
1868
1869    let add_project_collaborator = proto::AddProjectCollaborator {
1870        project_id: project_id.to_proto(),
1871        collaborator: Some(proto::Collaborator {
1872            peer_id: Some(session.connection_id.into()),
1873            replica_id: replica_id.0 as u32,
1874            user_id: guest_user_id.to_proto(),
1875            is_host: false,
1876        }),
1877    };
1878
1879    for collaborator in &collaborators {
1880        session
1881            .peer
1882            .send(
1883                collaborator.peer_id.unwrap().into(),
1884                add_project_collaborator.clone(),
1885            )
1886            .trace_err();
1887    }
1888
1889    // First, we send the metadata associated with each worktree.
1890    response.send(proto::JoinProjectResponse {
1891        project_id: project.id.0 as u64,
1892        worktrees: worktrees.clone(),
1893        replica_id: replica_id.0 as u32,
1894        collaborators: collaborators.clone(),
1895        language_servers: project.language_servers.clone(),
1896        role: project.role.into(),
1897    })?;
1898
1899    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1900        // Stream this worktree's entries.
1901        let message = proto::UpdateWorktree {
1902            project_id: project_id.to_proto(),
1903            worktree_id,
1904            abs_path: worktree.abs_path.clone(),
1905            root_name: worktree.root_name,
1906            updated_entries: worktree.entries,
1907            removed_entries: Default::default(),
1908            scan_id: worktree.scan_id,
1909            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1910            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1911            removed_repositories: Default::default(),
1912        };
1913        for update in proto::split_worktree_update(message) {
1914            session.peer.send(session.connection_id, update.clone())?;
1915        }
1916
1917        // Stream this worktree's diagnostics.
1918        for summary in worktree.diagnostic_summaries {
1919            session.peer.send(
1920                session.connection_id,
1921                proto::UpdateDiagnosticSummary {
1922                    project_id: project_id.to_proto(),
1923                    worktree_id: worktree.id,
1924                    summary: Some(summary),
1925                },
1926            )?;
1927        }
1928
1929        for settings_file in worktree.settings_files {
1930            session.peer.send(
1931                session.connection_id,
1932                proto::UpdateWorktreeSettings {
1933                    project_id: project_id.to_proto(),
1934                    worktree_id: worktree.id,
1935                    path: settings_file.path,
1936                    content: Some(settings_file.content),
1937                    kind: Some(settings_file.kind.to_proto() as i32),
1938                },
1939            )?;
1940        }
1941    }
1942
1943    for repository in mem::take(&mut project.repositories) {
1944        for update in split_repository_update(repository) {
1945            session.peer.send(session.connection_id, update)?;
1946        }
1947    }
1948
1949    for language_server in &project.language_servers {
1950        session.peer.send(
1951            session.connection_id,
1952            proto::UpdateLanguageServer {
1953                project_id: project_id.to_proto(),
1954                language_server_id: language_server.id,
1955                variant: Some(
1956                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1957                        proto::LspDiskBasedDiagnosticsUpdated {},
1958                    ),
1959                ),
1960            },
1961        )?;
1962    }
1963
1964    Ok(())
1965}
1966
1967/// Leave someone elses shared project.
1968async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1969    let sender_id = session.connection_id;
1970    let project_id = ProjectId::from_proto(request.project_id);
1971    let db = session.db().await;
1972
1973    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1974    tracing::info!(
1975        %project_id,
1976        "leave project"
1977    );
1978
1979    project_left(project, &session);
1980    if let Some(room) = room {
1981        room_updated(room, &session.peer);
1982    }
1983
1984    Ok(())
1985}
1986
1987/// Updates other participants with changes to the project
1988async fn update_project(
1989    request: proto::UpdateProject,
1990    response: Response<proto::UpdateProject>,
1991    session: Session,
1992) -> Result<()> {
1993    let project_id = ProjectId::from_proto(request.project_id);
1994    let (room, guest_connection_ids) = &*session
1995        .db()
1996        .await
1997        .update_project(project_id, session.connection_id, &request.worktrees)
1998        .await?;
1999    broadcast(
2000        Some(session.connection_id),
2001        guest_connection_ids.iter().copied(),
2002        |connection_id| {
2003            session
2004                .peer
2005                .forward_send(session.connection_id, connection_id, request.clone())
2006        },
2007    );
2008    if let Some(room) = room {
2009        room_updated(room, &session.peer);
2010    }
2011    response.send(proto::Ack {})?;
2012
2013    Ok(())
2014}
2015
2016/// Updates other participants with changes to the worktree
2017async fn update_worktree(
2018    request: proto::UpdateWorktree,
2019    response: Response<proto::UpdateWorktree>,
2020    session: Session,
2021) -> Result<()> {
2022    let guest_connection_ids = session
2023        .db()
2024        .await
2025        .update_worktree(&request, session.connection_id)
2026        .await?;
2027
2028    broadcast(
2029        Some(session.connection_id),
2030        guest_connection_ids.iter().copied(),
2031        |connection_id| {
2032            session
2033                .peer
2034                .forward_send(session.connection_id, connection_id, request.clone())
2035        },
2036    );
2037    response.send(proto::Ack {})?;
2038    Ok(())
2039}
2040
2041async fn update_repository(
2042    request: proto::UpdateRepository,
2043    response: Response<proto::UpdateRepository>,
2044    session: Session,
2045) -> Result<()> {
2046    let guest_connection_ids = session
2047        .db()
2048        .await
2049        .update_repository(&request, session.connection_id)
2050        .await?;
2051
2052    broadcast(
2053        Some(session.connection_id),
2054        guest_connection_ids.iter().copied(),
2055        |connection_id| {
2056            session
2057                .peer
2058                .forward_send(session.connection_id, connection_id, request.clone())
2059        },
2060    );
2061    response.send(proto::Ack {})?;
2062    Ok(())
2063}
2064
2065async fn remove_repository(
2066    request: proto::RemoveRepository,
2067    response: Response<proto::RemoveRepository>,
2068    session: Session,
2069) -> Result<()> {
2070    let guest_connection_ids = session
2071        .db()
2072        .await
2073        .remove_repository(&request, session.connection_id)
2074        .await?;
2075
2076    broadcast(
2077        Some(session.connection_id),
2078        guest_connection_ids.iter().copied(),
2079        |connection_id| {
2080            session
2081                .peer
2082                .forward_send(session.connection_id, connection_id, request.clone())
2083        },
2084    );
2085    response.send(proto::Ack {})?;
2086    Ok(())
2087}
2088
2089/// Updates other participants with changes to the diagnostics
2090async fn update_diagnostic_summary(
2091    message: proto::UpdateDiagnosticSummary,
2092    session: Session,
2093) -> Result<()> {
2094    let guest_connection_ids = session
2095        .db()
2096        .await
2097        .update_diagnostic_summary(&message, session.connection_id)
2098        .await?;
2099
2100    broadcast(
2101        Some(session.connection_id),
2102        guest_connection_ids.iter().copied(),
2103        |connection_id| {
2104            session
2105                .peer
2106                .forward_send(session.connection_id, connection_id, message.clone())
2107        },
2108    );
2109
2110    Ok(())
2111}
2112
2113/// Updates other participants with changes to the worktree settings
2114async fn update_worktree_settings(
2115    message: proto::UpdateWorktreeSettings,
2116    session: Session,
2117) -> Result<()> {
2118    let guest_connection_ids = session
2119        .db()
2120        .await
2121        .update_worktree_settings(&message, session.connection_id)
2122        .await?;
2123
2124    broadcast(
2125        Some(session.connection_id),
2126        guest_connection_ids.iter().copied(),
2127        |connection_id| {
2128            session
2129                .peer
2130                .forward_send(session.connection_id, connection_id, message.clone())
2131        },
2132    );
2133
2134    Ok(())
2135}
2136
2137/// Notify other participants that a language server has started.
2138async fn start_language_server(
2139    request: proto::StartLanguageServer,
2140    session: Session,
2141) -> Result<()> {
2142    let guest_connection_ids = session
2143        .db()
2144        .await
2145        .start_language_server(&request, session.connection_id)
2146        .await?;
2147
2148    broadcast(
2149        Some(session.connection_id),
2150        guest_connection_ids.iter().copied(),
2151        |connection_id| {
2152            session
2153                .peer
2154                .forward_send(session.connection_id, connection_id, request.clone())
2155        },
2156    );
2157    Ok(())
2158}
2159
2160/// Notify other participants that a language server has changed.
2161async fn update_language_server(
2162    request: proto::UpdateLanguageServer,
2163    session: Session,
2164) -> Result<()> {
2165    let project_id = ProjectId::from_proto(request.project_id);
2166    let project_connection_ids = session
2167        .db()
2168        .await
2169        .project_connection_ids(project_id, session.connection_id, true)
2170        .await?;
2171    broadcast(
2172        Some(session.connection_id),
2173        project_connection_ids.iter().copied(),
2174        |connection_id| {
2175            session
2176                .peer
2177                .forward_send(session.connection_id, connection_id, request.clone())
2178        },
2179    );
2180    Ok(())
2181}
2182
2183/// forward a project request to the host. These requests should be read only
2184/// as guests are allowed to send them.
2185async fn forward_read_only_project_request<T>(
2186    request: T,
2187    response: Response<T>,
2188    session: Session,
2189) -> Result<()>
2190where
2191    T: EntityMessage + RequestMessage,
2192{
2193    let project_id = ProjectId::from_proto(request.remote_entity_id());
2194    let host_connection_id = session
2195        .db()
2196        .await
2197        .host_for_read_only_project_request(project_id, session.connection_id)
2198        .await?;
2199    let payload = session
2200        .peer
2201        .forward_request(session.connection_id, host_connection_id, request)
2202        .await?;
2203    response.send(payload)?;
2204    Ok(())
2205}
2206
2207async fn forward_find_search_candidates_request(
2208    request: proto::FindSearchCandidates,
2209    response: Response<proto::FindSearchCandidates>,
2210    session: Session,
2211) -> Result<()> {
2212    let project_id = ProjectId::from_proto(request.remote_entity_id());
2213    let host_connection_id = session
2214        .db()
2215        .await
2216        .host_for_read_only_project_request(project_id, session.connection_id)
2217        .await?;
2218    let payload = session
2219        .peer
2220        .forward_request(session.connection_id, host_connection_id, request)
2221        .await?;
2222    response.send(payload)?;
2223    Ok(())
2224}
2225
2226/// forward a project request to the host. These requests are disallowed
2227/// for guests.
2228async fn forward_mutating_project_request<T>(
2229    request: T,
2230    response: Response<T>,
2231    session: Session,
2232) -> Result<()>
2233where
2234    T: EntityMessage + RequestMessage,
2235{
2236    let project_id = ProjectId::from_proto(request.remote_entity_id());
2237
2238    let host_connection_id = session
2239        .db()
2240        .await
2241        .host_for_mutating_project_request(project_id, session.connection_id)
2242        .await?;
2243    let payload = session
2244        .peer
2245        .forward_request(session.connection_id, host_connection_id, request)
2246        .await?;
2247    response.send(payload)?;
2248    Ok(())
2249}
2250
2251/// Notify other participants that a new buffer has been created
2252async fn create_buffer_for_peer(
2253    request: proto::CreateBufferForPeer,
2254    session: Session,
2255) -> Result<()> {
2256    session
2257        .db()
2258        .await
2259        .check_user_is_project_host(
2260            ProjectId::from_proto(request.project_id),
2261            session.connection_id,
2262        )
2263        .await?;
2264    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2265    session
2266        .peer
2267        .forward_send(session.connection_id, peer_id.into(), request)?;
2268    Ok(())
2269}
2270
2271/// Notify other participants that a buffer has been updated. This is
2272/// allowed for guests as long as the update is limited to selections.
2273async fn update_buffer(
2274    request: proto::UpdateBuffer,
2275    response: Response<proto::UpdateBuffer>,
2276    session: Session,
2277) -> Result<()> {
2278    let project_id = ProjectId::from_proto(request.project_id);
2279    let mut capability = Capability::ReadOnly;
2280
2281    for op in request.operations.iter() {
2282        match op.variant {
2283            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2284            Some(_) => capability = Capability::ReadWrite,
2285        }
2286    }
2287
2288    let host = {
2289        let guard = session
2290            .db()
2291            .await
2292            .connections_for_buffer_update(project_id, session.connection_id, capability)
2293            .await?;
2294
2295        let (host, guests) = &*guard;
2296
2297        broadcast(
2298            Some(session.connection_id),
2299            guests.clone(),
2300            |connection_id| {
2301                session
2302                    .peer
2303                    .forward_send(session.connection_id, connection_id, request.clone())
2304            },
2305        );
2306
2307        *host
2308    };
2309
2310    if host != session.connection_id {
2311        session
2312            .peer
2313            .forward_request(session.connection_id, host, request.clone())
2314            .await?;
2315    }
2316
2317    response.send(proto::Ack {})?;
2318    Ok(())
2319}
2320
2321async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2322    let project_id = ProjectId::from_proto(message.project_id);
2323
2324    let operation = message.operation.as_ref().context("invalid operation")?;
2325    let capability = match operation.variant.as_ref() {
2326        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2327            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2328                match buffer_op.variant {
2329                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2330                        Capability::ReadOnly
2331                    }
2332                    _ => Capability::ReadWrite,
2333                }
2334            } else {
2335                Capability::ReadWrite
2336            }
2337        }
2338        Some(_) => Capability::ReadWrite,
2339        None => Capability::ReadOnly,
2340    };
2341
2342    let guard = session
2343        .db()
2344        .await
2345        .connections_for_buffer_update(project_id, session.connection_id, capability)
2346        .await?;
2347
2348    let (host, guests) = &*guard;
2349
2350    broadcast(
2351        Some(session.connection_id),
2352        guests.iter().chain([host]).copied(),
2353        |connection_id| {
2354            session
2355                .peer
2356                .forward_send(session.connection_id, connection_id, message.clone())
2357        },
2358    );
2359
2360    Ok(())
2361}
2362
2363/// Notify other participants that a project has been updated.
2364async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2365    request: T,
2366    session: Session,
2367) -> Result<()> {
2368    let project_id = ProjectId::from_proto(request.remote_entity_id());
2369    let project_connection_ids = session
2370        .db()
2371        .await
2372        .project_connection_ids(project_id, session.connection_id, false)
2373        .await?;
2374
2375    broadcast(
2376        Some(session.connection_id),
2377        project_connection_ids.iter().copied(),
2378        |connection_id| {
2379            session
2380                .peer
2381                .forward_send(session.connection_id, connection_id, request.clone())
2382        },
2383    );
2384    Ok(())
2385}
2386
2387/// Start following another user in a call.
2388async fn follow(
2389    request: proto::Follow,
2390    response: Response<proto::Follow>,
2391    session: Session,
2392) -> Result<()> {
2393    let room_id = RoomId::from_proto(request.room_id);
2394    let project_id = request.project_id.map(ProjectId::from_proto);
2395    let leader_id = request
2396        .leader_id
2397        .ok_or_else(|| anyhow!("invalid leader id"))?
2398        .into();
2399    let follower_id = session.connection_id;
2400
2401    session
2402        .db()
2403        .await
2404        .check_room_participants(room_id, leader_id, session.connection_id)
2405        .await?;
2406
2407    let response_payload = session
2408        .peer
2409        .forward_request(session.connection_id, leader_id, request)
2410        .await?;
2411    response.send(response_payload)?;
2412
2413    if let Some(project_id) = project_id {
2414        let room = session
2415            .db()
2416            .await
2417            .follow(room_id, project_id, leader_id, follower_id)
2418            .await?;
2419        room_updated(&room, &session.peer);
2420    }
2421
2422    Ok(())
2423}
2424
2425/// Stop following another user in a call.
2426async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2427    let room_id = RoomId::from_proto(request.room_id);
2428    let project_id = request.project_id.map(ProjectId::from_proto);
2429    let leader_id = request
2430        .leader_id
2431        .ok_or_else(|| anyhow!("invalid leader id"))?
2432        .into();
2433    let follower_id = session.connection_id;
2434
2435    session
2436        .db()
2437        .await
2438        .check_room_participants(room_id, leader_id, session.connection_id)
2439        .await?;
2440
2441    session
2442        .peer
2443        .forward_send(session.connection_id, leader_id, request)?;
2444
2445    if let Some(project_id) = project_id {
2446        let room = session
2447            .db()
2448            .await
2449            .unfollow(room_id, project_id, leader_id, follower_id)
2450            .await?;
2451        room_updated(&room, &session.peer);
2452    }
2453
2454    Ok(())
2455}
2456
2457/// Notify everyone following you of your current location.
2458async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2459    let room_id = RoomId::from_proto(request.room_id);
2460    let database = session.db.lock().await;
2461
2462    let connection_ids = if let Some(project_id) = request.project_id {
2463        let project_id = ProjectId::from_proto(project_id);
2464        database
2465            .project_connection_ids(project_id, session.connection_id, true)
2466            .await?
2467    } else {
2468        database
2469            .room_connection_ids(room_id, session.connection_id)
2470            .await?
2471    };
2472
2473    // For now, don't send view update messages back to that view's current leader.
2474    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2475        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2476        _ => None,
2477    });
2478
2479    for connection_id in connection_ids.iter().cloned() {
2480        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2481            session
2482                .peer
2483                .forward_send(session.connection_id, connection_id, request.clone())?;
2484        }
2485    }
2486    Ok(())
2487}
2488
2489/// Get public data about users.
2490async fn get_users(
2491    request: proto::GetUsers,
2492    response: Response<proto::GetUsers>,
2493    session: Session,
2494) -> Result<()> {
2495    let user_ids = request
2496        .user_ids
2497        .into_iter()
2498        .map(UserId::from_proto)
2499        .collect();
2500    let users = session
2501        .db()
2502        .await
2503        .get_users_by_ids(user_ids)
2504        .await?
2505        .into_iter()
2506        .map(|user| proto::User {
2507            id: user.id.to_proto(),
2508            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2509            github_login: user.github_login,
2510            email: user.email_address,
2511            name: user.name,
2512        })
2513        .collect();
2514    response.send(proto::UsersResponse { users })?;
2515    Ok(())
2516}
2517
2518/// Search for users (to invite) buy Github login
2519async fn fuzzy_search_users(
2520    request: proto::FuzzySearchUsers,
2521    response: Response<proto::FuzzySearchUsers>,
2522    session: Session,
2523) -> Result<()> {
2524    let query = request.query;
2525    let users = match query.len() {
2526        0 => vec![],
2527        1 | 2 => session
2528            .db()
2529            .await
2530            .get_user_by_github_login(&query)
2531            .await?
2532            .into_iter()
2533            .collect(),
2534        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2535    };
2536    let users = users
2537        .into_iter()
2538        .filter(|user| user.id != session.user_id())
2539        .map(|user| proto::User {
2540            id: user.id.to_proto(),
2541            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2542            github_login: user.github_login,
2543            name: user.name,
2544            email: user.email_address,
2545        })
2546        .collect();
2547    response.send(proto::UsersResponse { users })?;
2548    Ok(())
2549}
2550
2551/// Send a contact request to another user.
2552async fn request_contact(
2553    request: proto::RequestContact,
2554    response: Response<proto::RequestContact>,
2555    session: Session,
2556) -> Result<()> {
2557    let requester_id = session.user_id();
2558    let responder_id = UserId::from_proto(request.responder_id);
2559    if requester_id == responder_id {
2560        return Err(anyhow!("cannot add yourself as a contact"))?;
2561    }
2562
2563    let notifications = session
2564        .db()
2565        .await
2566        .send_contact_request(requester_id, responder_id)
2567        .await?;
2568
2569    // Update outgoing contact requests of requester
2570    let mut update = proto::UpdateContacts::default();
2571    update.outgoing_requests.push(responder_id.to_proto());
2572    for connection_id in session
2573        .connection_pool()
2574        .await
2575        .user_connection_ids(requester_id)
2576    {
2577        session.peer.send(connection_id, update.clone())?;
2578    }
2579
2580    // Update incoming contact requests of responder
2581    let mut update = proto::UpdateContacts::default();
2582    update
2583        .incoming_requests
2584        .push(proto::IncomingContactRequest {
2585            requester_id: requester_id.to_proto(),
2586        });
2587    let connection_pool = session.connection_pool().await;
2588    for connection_id in connection_pool.user_connection_ids(responder_id) {
2589        session.peer.send(connection_id, update.clone())?;
2590    }
2591
2592    send_notifications(&connection_pool, &session.peer, notifications);
2593
2594    response.send(proto::Ack {})?;
2595    Ok(())
2596}
2597
2598/// Accept or decline a contact request
2599async fn respond_to_contact_request(
2600    request: proto::RespondToContactRequest,
2601    response: Response<proto::RespondToContactRequest>,
2602    session: Session,
2603) -> Result<()> {
2604    let responder_id = session.user_id();
2605    let requester_id = UserId::from_proto(request.requester_id);
2606    let db = session.db().await;
2607    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2608        db.dismiss_contact_notification(responder_id, requester_id)
2609            .await?;
2610    } else {
2611        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2612
2613        let notifications = db
2614            .respond_to_contact_request(responder_id, requester_id, accept)
2615            .await?;
2616        let requester_busy = db.is_user_busy(requester_id).await?;
2617        let responder_busy = db.is_user_busy(responder_id).await?;
2618
2619        let pool = session.connection_pool().await;
2620        // Update responder with new contact
2621        let mut update = proto::UpdateContacts::default();
2622        if accept {
2623            update
2624                .contacts
2625                .push(contact_for_user(requester_id, requester_busy, &pool));
2626        }
2627        update
2628            .remove_incoming_requests
2629            .push(requester_id.to_proto());
2630        for connection_id in pool.user_connection_ids(responder_id) {
2631            session.peer.send(connection_id, update.clone())?;
2632        }
2633
2634        // Update requester with new contact
2635        let mut update = proto::UpdateContacts::default();
2636        if accept {
2637            update
2638                .contacts
2639                .push(contact_for_user(responder_id, responder_busy, &pool));
2640        }
2641        update
2642            .remove_outgoing_requests
2643            .push(responder_id.to_proto());
2644
2645        for connection_id in pool.user_connection_ids(requester_id) {
2646            session.peer.send(connection_id, update.clone())?;
2647        }
2648
2649        send_notifications(&pool, &session.peer, notifications);
2650    }
2651
2652    response.send(proto::Ack {})?;
2653    Ok(())
2654}
2655
2656/// Remove a contact.
2657async fn remove_contact(
2658    request: proto::RemoveContact,
2659    response: Response<proto::RemoveContact>,
2660    session: Session,
2661) -> Result<()> {
2662    let requester_id = session.user_id();
2663    let responder_id = UserId::from_proto(request.user_id);
2664    let db = session.db().await;
2665    let (contact_accepted, deleted_notification_id) =
2666        db.remove_contact(requester_id, responder_id).await?;
2667
2668    let pool = session.connection_pool().await;
2669    // Update outgoing contact requests of requester
2670    let mut update = proto::UpdateContacts::default();
2671    if contact_accepted {
2672        update.remove_contacts.push(responder_id.to_proto());
2673    } else {
2674        update
2675            .remove_outgoing_requests
2676            .push(responder_id.to_proto());
2677    }
2678    for connection_id in pool.user_connection_ids(requester_id) {
2679        session.peer.send(connection_id, update.clone())?;
2680    }
2681
2682    // Update incoming contact requests of responder
2683    let mut update = proto::UpdateContacts::default();
2684    if contact_accepted {
2685        update.remove_contacts.push(requester_id.to_proto());
2686    } else {
2687        update
2688            .remove_incoming_requests
2689            .push(requester_id.to_proto());
2690    }
2691    for connection_id in pool.user_connection_ids(responder_id) {
2692        session.peer.send(connection_id, update.clone())?;
2693        if let Some(notification_id) = deleted_notification_id {
2694            session.peer.send(
2695                connection_id,
2696                proto::DeleteNotification {
2697                    notification_id: notification_id.to_proto(),
2698                },
2699            )?;
2700        }
2701    }
2702
2703    response.send(proto::Ack {})?;
2704    Ok(())
2705}
2706
2707fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2708    version.0.minor() < 139
2709}
2710
2711async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2712    let plan = session.current_plan(&session.db().await).await?;
2713
2714    session
2715        .peer
2716        .send(
2717            session.connection_id,
2718            proto::UpdateUserPlan { plan: plan.into() },
2719        )
2720        .trace_err();
2721
2722    Ok(())
2723}
2724
2725async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2726    subscribe_user_to_channels(session.user_id(), &session).await?;
2727    Ok(())
2728}
2729
2730async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2731    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2732    let mut pool = session.connection_pool().await;
2733    for membership in &channels_for_user.channel_memberships {
2734        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2735    }
2736    session.peer.send(
2737        session.connection_id,
2738        build_update_user_channels(&channels_for_user),
2739    )?;
2740    session.peer.send(
2741        session.connection_id,
2742        build_channels_update(channels_for_user),
2743    )?;
2744    Ok(())
2745}
2746
2747/// Creates a new channel.
2748async fn create_channel(
2749    request: proto::CreateChannel,
2750    response: Response<proto::CreateChannel>,
2751    session: Session,
2752) -> Result<()> {
2753    let db = session.db().await;
2754
2755    let parent_id = request.parent_id.map(ChannelId::from_proto);
2756    let (channel, membership) = db
2757        .create_channel(&request.name, parent_id, session.user_id())
2758        .await?;
2759
2760    let root_id = channel.root_id();
2761    let channel = Channel::from_model(channel);
2762
2763    response.send(proto::CreateChannelResponse {
2764        channel: Some(channel.to_proto()),
2765        parent_id: request.parent_id,
2766    })?;
2767
2768    let mut connection_pool = session.connection_pool().await;
2769    if let Some(membership) = membership {
2770        connection_pool.subscribe_to_channel(
2771            membership.user_id,
2772            membership.channel_id,
2773            membership.role,
2774        );
2775        let update = proto::UpdateUserChannels {
2776            channel_memberships: vec![proto::ChannelMembership {
2777                channel_id: membership.channel_id.to_proto(),
2778                role: membership.role.into(),
2779            }],
2780            ..Default::default()
2781        };
2782        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2783            session.peer.send(connection_id, update.clone())?;
2784        }
2785    }
2786
2787    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2788        if !role.can_see_channel(channel.visibility) {
2789            continue;
2790        }
2791
2792        let update = proto::UpdateChannels {
2793            channels: vec![channel.to_proto()],
2794            ..Default::default()
2795        };
2796        session.peer.send(connection_id, update.clone())?;
2797    }
2798
2799    Ok(())
2800}
2801
2802/// Delete a channel
2803async fn delete_channel(
2804    request: proto::DeleteChannel,
2805    response: Response<proto::DeleteChannel>,
2806    session: Session,
2807) -> Result<()> {
2808    let db = session.db().await;
2809
2810    let channel_id = request.channel_id;
2811    let (root_channel, removed_channels) = db
2812        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2813        .await?;
2814    response.send(proto::Ack {})?;
2815
2816    // Notify members of removed channels
2817    let mut update = proto::UpdateChannels::default();
2818    update
2819        .delete_channels
2820        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2821
2822    let connection_pool = session.connection_pool().await;
2823    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2824        session.peer.send(connection_id, update.clone())?;
2825    }
2826
2827    Ok(())
2828}
2829
2830/// Invite someone to join a channel.
2831async fn invite_channel_member(
2832    request: proto::InviteChannelMember,
2833    response: Response<proto::InviteChannelMember>,
2834    session: Session,
2835) -> Result<()> {
2836    let db = session.db().await;
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838    let invitee_id = UserId::from_proto(request.user_id);
2839    let InviteMemberResult {
2840        channel,
2841        notifications,
2842    } = db
2843        .invite_channel_member(
2844            channel_id,
2845            invitee_id,
2846            session.user_id(),
2847            request.role().into(),
2848        )
2849        .await?;
2850
2851    let update = proto::UpdateChannels {
2852        channel_invitations: vec![channel.to_proto()],
2853        ..Default::default()
2854    };
2855
2856    let connection_pool = session.connection_pool().await;
2857    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2858        session.peer.send(connection_id, update.clone())?;
2859    }
2860
2861    send_notifications(&connection_pool, &session.peer, notifications);
2862
2863    response.send(proto::Ack {})?;
2864    Ok(())
2865}
2866
2867/// remove someone from a channel
2868async fn remove_channel_member(
2869    request: proto::RemoveChannelMember,
2870    response: Response<proto::RemoveChannelMember>,
2871    session: Session,
2872) -> Result<()> {
2873    let db = session.db().await;
2874    let channel_id = ChannelId::from_proto(request.channel_id);
2875    let member_id = UserId::from_proto(request.user_id);
2876
2877    let RemoveChannelMemberResult {
2878        membership_update,
2879        notification_id,
2880    } = db
2881        .remove_channel_member(channel_id, member_id, session.user_id())
2882        .await?;
2883
2884    let mut connection_pool = session.connection_pool().await;
2885    notify_membership_updated(
2886        &mut connection_pool,
2887        membership_update,
2888        member_id,
2889        &session.peer,
2890    );
2891    for connection_id in connection_pool.user_connection_ids(member_id) {
2892        if let Some(notification_id) = notification_id {
2893            session
2894                .peer
2895                .send(
2896                    connection_id,
2897                    proto::DeleteNotification {
2898                        notification_id: notification_id.to_proto(),
2899                    },
2900                )
2901                .trace_err();
2902        }
2903    }
2904
2905    response.send(proto::Ack {})?;
2906    Ok(())
2907}
2908
2909/// Toggle the channel between public and private.
2910/// Care is taken to maintain the invariant that public channels only descend from public channels,
2911/// (though members-only channels can appear at any point in the hierarchy).
2912async fn set_channel_visibility(
2913    request: proto::SetChannelVisibility,
2914    response: Response<proto::SetChannelVisibility>,
2915    session: Session,
2916) -> Result<()> {
2917    let db = session.db().await;
2918    let channel_id = ChannelId::from_proto(request.channel_id);
2919    let visibility = request.visibility().into();
2920
2921    let channel_model = db
2922        .set_channel_visibility(channel_id, visibility, session.user_id())
2923        .await?;
2924    let root_id = channel_model.root_id();
2925    let channel = Channel::from_model(channel_model);
2926
2927    let mut connection_pool = session.connection_pool().await;
2928    for (user_id, role) in connection_pool
2929        .channel_user_ids(root_id)
2930        .collect::<Vec<_>>()
2931        .into_iter()
2932    {
2933        let update = if role.can_see_channel(channel.visibility) {
2934            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2935            proto::UpdateChannels {
2936                channels: vec![channel.to_proto()],
2937                ..Default::default()
2938            }
2939        } else {
2940            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2941            proto::UpdateChannels {
2942                delete_channels: vec![channel.id.to_proto()],
2943                ..Default::default()
2944            }
2945        };
2946
2947        for connection_id in connection_pool.user_connection_ids(user_id) {
2948            session.peer.send(connection_id, update.clone())?;
2949        }
2950    }
2951
2952    response.send(proto::Ack {})?;
2953    Ok(())
2954}
2955
2956/// Alter the role for a user in the channel.
2957async fn set_channel_member_role(
2958    request: proto::SetChannelMemberRole,
2959    response: Response<proto::SetChannelMemberRole>,
2960    session: Session,
2961) -> Result<()> {
2962    let db = session.db().await;
2963    let channel_id = ChannelId::from_proto(request.channel_id);
2964    let member_id = UserId::from_proto(request.user_id);
2965    let result = db
2966        .set_channel_member_role(
2967            channel_id,
2968            session.user_id(),
2969            member_id,
2970            request.role().into(),
2971        )
2972        .await?;
2973
2974    match result {
2975        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2976            let mut connection_pool = session.connection_pool().await;
2977            notify_membership_updated(
2978                &mut connection_pool,
2979                membership_update,
2980                member_id,
2981                &session.peer,
2982            )
2983        }
2984        db::SetMemberRoleResult::InviteUpdated(channel) => {
2985            let update = proto::UpdateChannels {
2986                channel_invitations: vec![channel.to_proto()],
2987                ..Default::default()
2988            };
2989
2990            for connection_id in session
2991                .connection_pool()
2992                .await
2993                .user_connection_ids(member_id)
2994            {
2995                session.peer.send(connection_id, update.clone())?;
2996            }
2997        }
2998    }
2999
3000    response.send(proto::Ack {})?;
3001    Ok(())
3002}
3003
3004/// Change the name of a channel
3005async fn rename_channel(
3006    request: proto::RenameChannel,
3007    response: Response<proto::RenameChannel>,
3008    session: Session,
3009) -> Result<()> {
3010    let db = session.db().await;
3011    let channel_id = ChannelId::from_proto(request.channel_id);
3012    let channel_model = db
3013        .rename_channel(channel_id, session.user_id(), &request.name)
3014        .await?;
3015    let root_id = channel_model.root_id();
3016    let channel = Channel::from_model(channel_model);
3017
3018    response.send(proto::RenameChannelResponse {
3019        channel: Some(channel.to_proto()),
3020    })?;
3021
3022    let connection_pool = session.connection_pool().await;
3023    let update = proto::UpdateChannels {
3024        channels: vec![channel.to_proto()],
3025        ..Default::default()
3026    };
3027    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3028        if role.can_see_channel(channel.visibility) {
3029            session.peer.send(connection_id, update.clone())?;
3030        }
3031    }
3032
3033    Ok(())
3034}
3035
3036/// Move a channel to a new parent.
3037async fn move_channel(
3038    request: proto::MoveChannel,
3039    response: Response<proto::MoveChannel>,
3040    session: Session,
3041) -> Result<()> {
3042    let channel_id = ChannelId::from_proto(request.channel_id);
3043    let to = ChannelId::from_proto(request.to);
3044
3045    let (root_id, channels) = session
3046        .db()
3047        .await
3048        .move_channel(channel_id, to, session.user_id())
3049        .await?;
3050
3051    let connection_pool = session.connection_pool().await;
3052    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3053        let channels = channels
3054            .iter()
3055            .filter_map(|channel| {
3056                if role.can_see_channel(channel.visibility) {
3057                    Some(channel.to_proto())
3058                } else {
3059                    None
3060                }
3061            })
3062            .collect::<Vec<_>>();
3063        if channels.is_empty() {
3064            continue;
3065        }
3066
3067        let update = proto::UpdateChannels {
3068            channels,
3069            ..Default::default()
3070        };
3071
3072        session.peer.send(connection_id, update.clone())?;
3073    }
3074
3075    response.send(Ack {})?;
3076    Ok(())
3077}
3078
3079/// Get the list of channel members
3080async fn get_channel_members(
3081    request: proto::GetChannelMembers,
3082    response: Response<proto::GetChannelMembers>,
3083    session: Session,
3084) -> Result<()> {
3085    let db = session.db().await;
3086    let channel_id = ChannelId::from_proto(request.channel_id);
3087    let limit = if request.limit == 0 {
3088        u16::MAX as u64
3089    } else {
3090        request.limit
3091    };
3092    let (members, users) = db
3093        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3094        .await?;
3095    response.send(proto::GetChannelMembersResponse { members, users })?;
3096    Ok(())
3097}
3098
3099/// Accept or decline a channel invitation.
3100async fn respond_to_channel_invite(
3101    request: proto::RespondToChannelInvite,
3102    response: Response<proto::RespondToChannelInvite>,
3103    session: Session,
3104) -> Result<()> {
3105    let db = session.db().await;
3106    let channel_id = ChannelId::from_proto(request.channel_id);
3107    let RespondToChannelInvite {
3108        membership_update,
3109        notifications,
3110    } = db
3111        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3112        .await?;
3113
3114    let mut connection_pool = session.connection_pool().await;
3115    if let Some(membership_update) = membership_update {
3116        notify_membership_updated(
3117            &mut connection_pool,
3118            membership_update,
3119            session.user_id(),
3120            &session.peer,
3121        );
3122    } else {
3123        let update = proto::UpdateChannels {
3124            remove_channel_invitations: vec![channel_id.to_proto()],
3125            ..Default::default()
3126        };
3127
3128        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3129            session.peer.send(connection_id, update.clone())?;
3130        }
3131    };
3132
3133    send_notifications(&connection_pool, &session.peer, notifications);
3134
3135    response.send(proto::Ack {})?;
3136
3137    Ok(())
3138}
3139
3140/// Join the channels' room
3141async fn join_channel(
3142    request: proto::JoinChannel,
3143    response: Response<proto::JoinChannel>,
3144    session: Session,
3145) -> Result<()> {
3146    let channel_id = ChannelId::from_proto(request.channel_id);
3147    join_channel_internal(channel_id, Box::new(response), session).await
3148}
3149
/// Abstracts over the response handle used to answer a channel join.
///
/// This lets `join_channel_internal` reply with a `JoinRoomResponse` whether the
/// original request was a `JoinChannel` or a `JoinRoom`.
trait JoinChannelInternalResponse {
    /// Consume the response handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// NOTE: the fully qualified `Response::<_>::send` calls below resolve to
// `Response`'s inherent `send` (inherent associated items take precedence over
// trait items), so these impls delegate rather than recurse into themselves.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3163
/// Shared implementation for joining a channel's room.
///
/// It is called by `join_channel`; the `JoinChannelInternalResponse` impls
/// suggest a `JoinRoom` handler uses it too.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, release the
///    database lock, leave that room via `leave_room_for_session`, and reacquire
///    the lock.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, build connection info. Guests get a
///    guest token and `can_publish: false`; everyone else gets a room token and
///    `can_publish: true`. If token creation fails, the error is logged and the
///    response carries no LiveKit info.
/// 4. Reply with a `JoinRoomResponse`, propagate any membership change, and
///    broadcast the updated room.
/// 5. Broadcast the channel via `channel_updated` and call `update_user_contacts`.
///
/// # Errors
///
/// Returns an error if any of the following fail:
/// - a database call;
/// - leaving the stale room;
/// - sending the response;
/// - `update_user_contacts`.
///
/// It also fails if the database did not return the channel. That check happens
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Both the database guard and the connection-pool guard acquired inside this
    // block are dropped at its end, before `channel_updated` acquires the pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock while leaving the stale room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails:
        // `trace_err` logs the error and the `?` inside the closure yields `None`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Note: the caller has already received its response by the time this check runs.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3259
3260/// Start editing the channel notes
3261async fn join_channel_buffer(
3262    request: proto::JoinChannelBuffer,
3263    response: Response<proto::JoinChannelBuffer>,
3264    session: Session,
3265) -> Result<()> {
3266    let db = session.db().await;
3267    let channel_id = ChannelId::from_proto(request.channel_id);
3268
3269    let open_response = db
3270        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3271        .await?;
3272
3273    let collaborators = open_response.collaborators.clone();
3274    response.send(open_response)?;
3275
3276    let update = UpdateChannelBufferCollaborators {
3277        channel_id: channel_id.to_proto(),
3278        collaborators: collaborators.clone(),
3279    };
3280    channel_buffer_updated(
3281        session.connection_id,
3282        collaborators
3283            .iter()
3284            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3285        &update,
3286        &session.peer,
3287    );
3288
3289    Ok(())
3290}
3291
3292/// Edit the channel notes
3293async fn update_channel_buffer(
3294    request: proto::UpdateChannelBuffer,
3295    session: Session,
3296) -> Result<()> {
3297    let db = session.db().await;
3298    let channel_id = ChannelId::from_proto(request.channel_id);
3299
3300    let (collaborators, epoch, version) = db
3301        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3302        .await?;
3303
3304    channel_buffer_updated(
3305        session.connection_id,
3306        collaborators.clone(),
3307        &proto::UpdateChannelBuffer {
3308            channel_id: channel_id.to_proto(),
3309            operations: request.operations,
3310        },
3311        &session.peer,
3312    );
3313
3314    let pool = &*session.connection_pool().await;
3315
3316    let non_collaborators =
3317        pool.channel_connection_ids(channel_id)
3318            .filter_map(|(connection_id, _)| {
3319                if collaborators.contains(&connection_id) {
3320                    None
3321                } else {
3322                    Some(connection_id)
3323                }
3324            });
3325
3326    broadcast(None, non_collaborators, |peer_id| {
3327        session.peer.send(
3328            peer_id,
3329            proto::UpdateChannels {
3330                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3331                    channel_id: channel_id.to_proto(),
3332                    epoch: epoch as u64,
3333                    version: version.clone(),
3334                }],
3335                ..Default::default()
3336            },
3337        )
3338    });
3339
3340    Ok(())
3341}
3342
3343/// Rejoin the channel notes after a connection blip
3344async fn rejoin_channel_buffers(
3345    request: proto::RejoinChannelBuffers,
3346    response: Response<proto::RejoinChannelBuffers>,
3347    session: Session,
3348) -> Result<()> {
3349    let db = session.db().await;
3350    let buffers = db
3351        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3352        .await?;
3353
3354    for rejoined_buffer in &buffers {
3355        let collaborators_to_notify = rejoined_buffer
3356            .buffer
3357            .collaborators
3358            .iter()
3359            .filter_map(|c| Some(c.peer_id?.into()));
3360        channel_buffer_updated(
3361            session.connection_id,
3362            collaborators_to_notify,
3363            &proto::UpdateChannelBufferCollaborators {
3364                channel_id: rejoined_buffer.buffer.channel_id,
3365                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3366            },
3367            &session.peer,
3368        );
3369    }
3370
3371    response.send(proto::RejoinChannelBuffersResponse {
3372        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3373    })?;
3374
3375    Ok(())
3376}
3377
3378/// Stop editing the channel notes
3379async fn leave_channel_buffer(
3380    request: proto::LeaveChannelBuffer,
3381    response: Response<proto::LeaveChannelBuffer>,
3382    session: Session,
3383) -> Result<()> {
3384    let db = session.db().await;
3385    let channel_id = ChannelId::from_proto(request.channel_id);
3386
3387    let left_buffer = db
3388        .leave_channel_buffer(channel_id, session.connection_id)
3389        .await?;
3390
3391    response.send(Ack {})?;
3392
3393    channel_buffer_updated(
3394        session.connection_id,
3395        left_buffer.connections,
3396        &proto::UpdateChannelBufferCollaborators {
3397            channel_id: channel_id.to_proto(),
3398            collaborators: left_buffer.collaborators,
3399        },
3400        &session.peer,
3401    );
3402
3403    Ok(())
3404}
3405
3406fn channel_buffer_updated<T: EnvelopedMessage>(
3407    sender_id: ConnectionId,
3408    collaborators: impl IntoIterator<Item = ConnectionId>,
3409    message: &T,
3410    peer: &Peer,
3411) {
3412    broadcast(Some(sender_id), collaborators, |peer_id| {
3413        peer.send(peer_id, message.clone())
3414    });
3415}
3416
3417fn send_notifications(
3418    connection_pool: &ConnectionPool,
3419    peer: &Peer,
3420    notifications: db::NotificationBatch,
3421) {
3422    for (user_id, notification) in notifications {
3423        for connection_id in connection_pool.user_connection_ids(user_id) {
3424            if let Err(error) = peer.send(
3425                connection_id,
3426                proto::AddNotification {
3427                    notification: Some(notification.clone()),
3428                },
3429            ) {
3430                tracing::error!(
3431                    "failed to send notification to {:?} {}",
3432                    connection_id,
3433                    error
3434                );
3435            }
3436        }
3437    }
3438}
3439
3440/// Send a message to the channel
3441async fn send_channel_message(
3442    request: proto::SendChannelMessage,
3443    response: Response<proto::SendChannelMessage>,
3444    session: Session,
3445) -> Result<()> {
3446    // Validate the message body.
3447    let body = request.body.trim().to_string();
3448    if body.len() > MAX_MESSAGE_LEN {
3449        return Err(anyhow!("message is too long"))?;
3450    }
3451    if body.is_empty() {
3452        return Err(anyhow!("message can't be blank"))?;
3453    }
3454
3455    // TODO: adjust mentions if body is trimmed
3456
3457    let timestamp = OffsetDateTime::now_utc();
3458    let nonce = request
3459        .nonce
3460        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3461
3462    let channel_id = ChannelId::from_proto(request.channel_id);
3463    let CreatedChannelMessage {
3464        message_id,
3465        participant_connection_ids,
3466        notifications,
3467    } = session
3468        .db()
3469        .await
3470        .create_channel_message(
3471            channel_id,
3472            session.user_id(),
3473            &body,
3474            &request.mentions,
3475            timestamp,
3476            nonce.clone().into(),
3477            request.reply_to_message_id.map(MessageId::from_proto),
3478        )
3479        .await?;
3480
3481    let message = proto::ChannelMessage {
3482        sender_id: session.user_id().to_proto(),
3483        id: message_id.to_proto(),
3484        body,
3485        mentions: request.mentions,
3486        timestamp: timestamp.unix_timestamp() as u64,
3487        nonce: Some(nonce),
3488        reply_to_message_id: request.reply_to_message_id,
3489        edited_at: None,
3490    };
3491    broadcast(
3492        Some(session.connection_id),
3493        participant_connection_ids.clone(),
3494        |connection| {
3495            session.peer.send(
3496                connection,
3497                proto::ChannelMessageSent {
3498                    channel_id: channel_id.to_proto(),
3499                    message: Some(message.clone()),
3500                },
3501            )
3502        },
3503    );
3504    response.send(proto::SendChannelMessageResponse {
3505        message: Some(message),
3506    })?;
3507
3508    let pool = &*session.connection_pool().await;
3509    let non_participants =
3510        pool.channel_connection_ids(channel_id)
3511            .filter_map(|(connection_id, _)| {
3512                if participant_connection_ids.contains(&connection_id) {
3513                    None
3514                } else {
3515                    Some(connection_id)
3516                }
3517            });
3518    broadcast(None, non_participants, |peer_id| {
3519        session.peer.send(
3520            peer_id,
3521            proto::UpdateChannels {
3522                latest_channel_message_ids: vec![proto::ChannelMessageId {
3523                    channel_id: channel_id.to_proto(),
3524                    message_id: message_id.to_proto(),
3525                }],
3526                ..Default::default()
3527            },
3528        )
3529    });
3530    send_notifications(pool, &session.peer, notifications);
3531
3532    Ok(())
3533}
3534
3535/// Delete a channel message
3536async fn remove_channel_message(
3537    request: proto::RemoveChannelMessage,
3538    response: Response<proto::RemoveChannelMessage>,
3539    session: Session,
3540) -> Result<()> {
3541    let channel_id = ChannelId::from_proto(request.channel_id);
3542    let message_id = MessageId::from_proto(request.message_id);
3543    let (connection_ids, existing_notification_ids) = session
3544        .db()
3545        .await
3546        .remove_channel_message(channel_id, message_id, session.user_id())
3547        .await?;
3548
3549    broadcast(
3550        Some(session.connection_id),
3551        connection_ids,
3552        move |connection| {
3553            session.peer.send(connection, request.clone())?;
3554
3555            for notification_id in &existing_notification_ids {
3556                session.peer.send(
3557                    connection,
3558                    proto::DeleteNotification {
3559                        notification_id: (*notification_id).to_proto(),
3560                    },
3561                )?;
3562            }
3563
3564            Ok(())
3565        },
3566    );
3567    response.send(proto::Ack {})?;
3568    Ok(())
3569}
3570
3571async fn update_channel_message(
3572    request: proto::UpdateChannelMessage,
3573    response: Response<proto::UpdateChannelMessage>,
3574    session: Session,
3575) -> Result<()> {
3576    let channel_id = ChannelId::from_proto(request.channel_id);
3577    let message_id = MessageId::from_proto(request.message_id);
3578    let updated_at = OffsetDateTime::now_utc();
3579    let UpdatedChannelMessage {
3580        message_id,
3581        participant_connection_ids,
3582        notifications,
3583        reply_to_message_id,
3584        timestamp,
3585        deleted_mention_notification_ids,
3586        updated_mention_notifications,
3587    } = session
3588        .db()
3589        .await
3590        .update_channel_message(
3591            channel_id,
3592            message_id,
3593            session.user_id(),
3594            request.body.as_str(),
3595            &request.mentions,
3596            updated_at,
3597        )
3598        .await?;
3599
3600    let nonce = request
3601        .nonce
3602        .clone()
3603        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3604
3605    let message = proto::ChannelMessage {
3606        sender_id: session.user_id().to_proto(),
3607        id: message_id.to_proto(),
3608        body: request.body.clone(),
3609        mentions: request.mentions.clone(),
3610        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3611        nonce: Some(nonce),
3612        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3613        edited_at: Some(updated_at.unix_timestamp() as u64),
3614    };
3615
3616    response.send(proto::Ack {})?;
3617
3618    let pool = &*session.connection_pool().await;
3619    broadcast(
3620        Some(session.connection_id),
3621        participant_connection_ids,
3622        |connection| {
3623            session.peer.send(
3624                connection,
3625                proto::ChannelMessageUpdate {
3626                    channel_id: channel_id.to_proto(),
3627                    message: Some(message.clone()),
3628                },
3629            )?;
3630
3631            for notification_id in &deleted_mention_notification_ids {
3632                session.peer.send(
3633                    connection,
3634                    proto::DeleteNotification {
3635                        notification_id: (*notification_id).to_proto(),
3636                    },
3637                )?;
3638            }
3639
3640            for notification in &updated_mention_notifications {
3641                session.peer.send(
3642                    connection,
3643                    proto::UpdateNotification {
3644                        notification: Some(notification.clone()),
3645                    },
3646                )?;
3647            }
3648
3649            Ok(())
3650        },
3651    );
3652
3653    send_notifications(pool, &session.peer, notifications);
3654
3655    Ok(())
3656}
3657
3658/// Mark a channel message as read
3659async fn acknowledge_channel_message(
3660    request: proto::AckChannelMessage,
3661    session: Session,
3662) -> Result<()> {
3663    let channel_id = ChannelId::from_proto(request.channel_id);
3664    let message_id = MessageId::from_proto(request.message_id);
3665    let notifications = session
3666        .db()
3667        .await
3668        .observe_channel_message(channel_id, session.user_id(), message_id)
3669        .await?;
3670    send_notifications(
3671        &*session.connection_pool().await,
3672        &session.peer,
3673        notifications,
3674    );
3675    Ok(())
3676}
3677
3678/// Mark a buffer version as synced
3679async fn acknowledge_buffer_version(
3680    request: proto::AckBufferOperation,
3681    session: Session,
3682) -> Result<()> {
3683    let buffer_id = BufferId::from_proto(request.buffer_id);
3684    session
3685        .db()
3686        .await
3687        .observe_buffer_version(
3688            buffer_id,
3689            session.user_id(),
3690            request.epoch as i32,
3691            &request.version,
3692        )
3693        .await?;
3694    Ok(())
3695}
3696
3697async fn count_language_model_tokens(
3698    request: proto::CountLanguageModelTokens,
3699    response: Response<proto::CountLanguageModelTokens>,
3700    session: Session,
3701    config: &Config,
3702) -> Result<()> {
3703    authorize_access_to_legacy_llm_endpoints(&session).await?;
3704
3705    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3706        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3707        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3708    };
3709
3710    session
3711        .app_state
3712        .rate_limiter
3713        .check(&*rate_limit, session.user_id())
3714        .await?;
3715
3716    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3717        Some(proto::LanguageModelProvider::Google) => {
3718            let api_key = config
3719                .google_ai_api_key
3720                .as_ref()
3721                .context("no Google AI API key configured on the server")?;
3722            google_ai::count_tokens(
3723                session.http_client.as_ref(),
3724                google_ai::API_URL,
3725                api_key,
3726                serde_json::from_str(&request.request)?,
3727            )
3728            .await?
3729        }
3730        _ => return Err(anyhow!("unsupported provider"))?,
3731    };
3732
3733    response.send(proto::CountLanguageModelTokensResponse {
3734        token_count: result.total_tokens as u32,
3735    })?;
3736
3737    Ok(())
3738}
3739
3740struct ZedProCountLanguageModelTokensRateLimit;
3741
3742impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3743    fn capacity(&self) -> usize {
3744        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3745            .ok()
3746            .and_then(|v| v.parse().ok())
3747            .unwrap_or(600) // Picked arbitrarily
3748    }
3749
3750    fn refill_duration(&self) -> chrono::Duration {
3751        chrono::Duration::hours(1)
3752    }
3753
3754    fn db_name(&self) -> &'static str {
3755        "zed-pro:count-language-model-tokens"
3756    }
3757}
3758
3759struct FreeCountLanguageModelTokensRateLimit;
3760
3761impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3762    fn capacity(&self) -> usize {
3763        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3764            .ok()
3765            .and_then(|v| v.parse().ok())
3766            .unwrap_or(600 / 10) // Picked arbitrarily
3767    }
3768
3769    fn refill_duration(&self) -> chrono::Duration {
3770        chrono::Duration::hours(1)
3771    }
3772
3773    fn db_name(&self) -> &'static str {
3774        "free:count-language-model-tokens"
3775    }
3776}
3777
3778struct ZedProComputeEmbeddingsRateLimit;
3779
3780impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3781    fn capacity(&self) -> usize {
3782        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3783            .ok()
3784            .and_then(|v| v.parse().ok())
3785            .unwrap_or(5000) // Picked arbitrarily
3786    }
3787
3788    fn refill_duration(&self) -> chrono::Duration {
3789        chrono::Duration::hours(1)
3790    }
3791
3792    fn db_name(&self) -> &'static str {
3793        "zed-pro:compute-embeddings"
3794    }
3795}
3796
3797struct FreeComputeEmbeddingsRateLimit;
3798
3799impl RateLimit for FreeComputeEmbeddingsRateLimit {
3800    fn capacity(&self) -> usize {
3801        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3802            .ok()
3803            .and_then(|v| v.parse().ok())
3804            .unwrap_or(5000 / 10) // Picked arbitrarily
3805    }
3806
3807    fn refill_duration(&self) -> chrono::Duration {
3808        chrono::Duration::hours(1)
3809    }
3810
3811    fn db_name(&self) -> &'static str {
3812        "free:compute-embeddings"
3813    }
3814}
3815
3816async fn compute_embeddings(
3817    request: proto::ComputeEmbeddings,
3818    response: Response<proto::ComputeEmbeddings>,
3819    session: Session,
3820    api_key: Option<Arc<str>>,
3821) -> Result<()> {
3822    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3823    authorize_access_to_legacy_llm_endpoints(&session).await?;
3824
3825    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3826        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3827        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3828    };
3829
3830    session
3831        .app_state
3832        .rate_limiter
3833        .check(&*rate_limit, session.user_id())
3834        .await?;
3835
3836    let embeddings = match request.model.as_str() {
3837        "openai/text-embedding-3-small" => {
3838            open_ai::embed(
3839                session.http_client.as_ref(),
3840                OPEN_AI_API_URL,
3841                &api_key,
3842                OpenAiEmbeddingModel::TextEmbedding3Small,
3843                request.texts.iter().map(|text| text.as_str()),
3844            )
3845            .await?
3846        }
3847        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3848    };
3849
3850    let embeddings = request
3851        .texts
3852        .iter()
3853        .map(|text| {
3854            let mut hasher = sha2::Sha256::new();
3855            hasher.update(text.as_bytes());
3856            let result = hasher.finalize();
3857            result.to_vec()
3858        })
3859        .zip(
3860            embeddings
3861                .data
3862                .into_iter()
3863                .map(|embedding| embedding.embedding),
3864        )
3865        .collect::<HashMap<_, _>>();
3866
3867    let db = session.db().await;
3868    db.save_embeddings(&request.model, &embeddings)
3869        .await
3870        .context("failed to save embeddings")
3871        .trace_err();
3872
3873    response.send(proto::ComputeEmbeddingsResponse {
3874        embeddings: embeddings
3875            .into_iter()
3876            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3877            .collect(),
3878    })?;
3879    Ok(())
3880}
3881
3882async fn get_cached_embeddings(
3883    request: proto::GetCachedEmbeddings,
3884    response: Response<proto::GetCachedEmbeddings>,
3885    session: Session,
3886) -> Result<()> {
3887    authorize_access_to_legacy_llm_endpoints(&session).await?;
3888
3889    let db = session.db().await;
3890    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3891
3892    response.send(proto::GetCachedEmbeddingsResponse {
3893        embeddings: embeddings
3894            .into_iter()
3895            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3896            .collect(),
3897    })?;
3898    Ok(())
3899}
3900
3901/// This is leftover from before the LLM service.
3902///
3903/// The endpoints protected by this check will be moved there eventually.
3904async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3905    if session.is_staff() {
3906        Ok(())
3907    } else {
3908        Err(anyhow!("permission denied"))?
3909    }
3910}
3911
3912/// Get a Supermaven API key for the user
3913async fn get_supermaven_api_key(
3914    _request: proto::GetSupermavenApiKey,
3915    response: Response<proto::GetSupermavenApiKey>,
3916    session: Session,
3917) -> Result<()> {
3918    let user_id: String = session.user_id().to_string();
3919    if !session.is_staff() {
3920        return Err(anyhow!("supermaven not enabled for this account"))?;
3921    }
3922
3923    let email = session
3924        .email()
3925        .ok_or_else(|| anyhow!("user must have an email"))?;
3926
3927    let supermaven_admin_api = session
3928        .supermaven_client
3929        .as_ref()
3930        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3931
3932    let result = supermaven_admin_api
3933        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3934        .await?;
3935
3936    response.send(proto::GetSupermavenApiKeyResponse {
3937        api_key: result.api_key,
3938    })?;
3939
3940    Ok(())
3941}
3942
3943/// Start receiving chat updates for a channel
3944async fn join_channel_chat(
3945    request: proto::JoinChannelChat,
3946    response: Response<proto::JoinChannelChat>,
3947    session: Session,
3948) -> Result<()> {
3949    let channel_id = ChannelId::from_proto(request.channel_id);
3950
3951    let db = session.db().await;
3952    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3953        .await?;
3954    let messages = db
3955        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3956        .await?;
3957    response.send(proto::JoinChannelChatResponse {
3958        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3959        messages,
3960    })?;
3961    Ok(())
3962}
3963
3964/// Stop receiving chat updates for a channel
3965async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3966    let channel_id = ChannelId::from_proto(request.channel_id);
3967    session
3968        .db()
3969        .await
3970        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3971        .await?;
3972    Ok(())
3973}
3974
3975/// Retrieve the chat history for a channel
3976async fn get_channel_messages(
3977    request: proto::GetChannelMessages,
3978    response: Response<proto::GetChannelMessages>,
3979    session: Session,
3980) -> Result<()> {
3981    let channel_id = ChannelId::from_proto(request.channel_id);
3982    let messages = session
3983        .db()
3984        .await
3985        .get_channel_messages(
3986            channel_id,
3987            session.user_id(),
3988            MESSAGE_COUNT_PER_PAGE,
3989            Some(MessageId::from_proto(request.before_message_id)),
3990        )
3991        .await?;
3992    response.send(proto::GetChannelMessagesResponse {
3993        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3994        messages,
3995    })?;
3996    Ok(())
3997}
3998
3999/// Retrieve specific chat messages
4000async fn get_channel_messages_by_id(
4001    request: proto::GetChannelMessagesById,
4002    response: Response<proto::GetChannelMessagesById>,
4003    session: Session,
4004) -> Result<()> {
4005    let message_ids = request
4006        .message_ids
4007        .iter()
4008        .map(|id| MessageId::from_proto(*id))
4009        .collect::<Vec<_>>();
4010    let messages = session
4011        .db()
4012        .await
4013        .get_channel_messages_by_id(session.user_id(), &message_ids)
4014        .await?;
4015    response.send(proto::GetChannelMessagesResponse {
4016        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4017        messages,
4018    })?;
4019    Ok(())
4020}
4021
4022/// Retrieve the current users notifications
4023async fn get_notifications(
4024    request: proto::GetNotifications,
4025    response: Response<proto::GetNotifications>,
4026    session: Session,
4027) -> Result<()> {
4028    let notifications = session
4029        .db()
4030        .await
4031        .get_notifications(
4032            session.user_id(),
4033            NOTIFICATION_COUNT_PER_PAGE,
4034            request.before_id.map(db::NotificationId::from_proto),
4035        )
4036        .await?;
4037    response.send(proto::GetNotificationsResponse {
4038        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4039        notifications,
4040    })?;
4041    Ok(())
4042}
4043
4044/// Mark notifications as read
4045async fn mark_notification_as_read(
4046    request: proto::MarkNotificationRead,
4047    response: Response<proto::MarkNotificationRead>,
4048    session: Session,
4049) -> Result<()> {
4050    let database = &session.db().await;
4051    let notifications = database
4052        .mark_notification_as_read_by_id(
4053            session.user_id(),
4054            NotificationId::from_proto(request.notification_id),
4055        )
4056        .await?;
4057    send_notifications(
4058        &*session.connection_pool().await,
4059        &session.peer,
4060        notifications,
4061    );
4062    response.send(proto::Ack {})?;
4063    Ok(())
4064}
4065
4066/// Get the current users information
4067async fn get_private_user_info(
4068    _request: proto::GetPrivateUserInfo,
4069    response: Response<proto::GetPrivateUserInfo>,
4070    session: Session,
4071) -> Result<()> {
4072    let db = session.db().await;
4073
4074    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4075    let user = db
4076        .get_user_by_id(session.user_id())
4077        .await?
4078        .ok_or_else(|| anyhow!("user not found"))?;
4079    let flags = db.get_user_flags(session.user_id()).await?;
4080
4081    response.send(proto::GetPrivateUserInfoResponse {
4082        metrics_id,
4083        staff: user.admin,
4084        flags,
4085        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4086    })?;
4087    Ok(())
4088}
4089
4090/// Accept the terms of service (tos) on behalf of the current user
4091async fn accept_terms_of_service(
4092    _request: proto::AcceptTermsOfService,
4093    response: Response<proto::AcceptTermsOfService>,
4094    session: Session,
4095) -> Result<()> {
4096    let db = session.db().await;
4097
4098    let accepted_tos_at = Utc::now();
4099    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4100        .await?;
4101
4102    response.send(proto::AcceptTermsOfServiceResponse {
4103        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4104    })?;
4105    Ok(())
4106}
4107
4108/// The minimum account age an account must have in order to use the LLM service.
4109pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4110
4111async fn get_llm_api_token(
4112    _request: proto::GetLlmToken,
4113    response: Response<proto::GetLlmToken>,
4114    session: Session,
4115) -> Result<()> {
4116    let db = session.db().await;
4117
4118    let flags = db.get_user_flags(session.user_id()).await?;
4119    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4120
4121    if !session.is_staff() && !has_language_models_feature_flag {
4122        Err(anyhow!("permission denied"))?
4123    }
4124
4125    let user_id = session.user_id();
4126    let user = db
4127        .get_user_by_id(user_id)
4128        .await?
4129        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4130
4131    if user.accepted_tos_at.is_none() {
4132        Err(anyhow!("terms of service not accepted"))?
4133    }
4134
4135    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4136    let billing_preferences = db.get_billing_preferences(user.id).await?;
4137
4138    let token = LlmTokenClaims::create(
4139        &user,
4140        session.is_staff(),
4141        billing_preferences,
4142        &flags,
4143        has_llm_subscription,
4144        session.current_plan(&db).await?,
4145        session.system_id.clone(),
4146        &session.app_state.config,
4147    )?;
4148    response.send(proto::GetLlmTokenResponse { token })?;
4149    Ok(())
4150}
4151
4152fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4153    let message = match message {
4154        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4155        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4156        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4157        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4158        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4159            code: frame.code.into(),
4160            reason: frame.reason,
4161        })),
4162        // We should never receive a frame while reading the message, according
4163        // to the `tungstenite` maintainers:
4164        //
4165        // > It cannot occur when you read messages from the WebSocket, but it
4166        // > can be used when you want to send the raw frames (e.g. you want to
4167        // > send the frames to the WebSocket without composing the full message first).
4168        // >
4169        // > — https://github.com/snapview/tungstenite-rs/issues/268
4170        TungsteniteMessage::Frame(_) => {
4171            bail!("received an unexpected frame while reading the message")
4172        }
4173    };
4174
4175    Ok(message)
4176}
4177
4178fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4179    match message {
4180        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4181        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4182        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4183        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4184        AxumMessage::Close(frame) => {
4185            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4186                code: frame.code.into(),
4187                reason: frame.reason,
4188            }))
4189        }
4190    }
4191}
4192
4193fn notify_membership_updated(
4194    connection_pool: &mut ConnectionPool,
4195    result: MembershipUpdated,
4196    user_id: UserId,
4197    peer: &Peer,
4198) {
4199    for membership in &result.new_channels.channel_memberships {
4200        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4201    }
4202    for channel_id in &result.removed_channels {
4203        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4204    }
4205
4206    let user_channels_update = proto::UpdateUserChannels {
4207        channel_memberships: result
4208            .new_channels
4209            .channel_memberships
4210            .iter()
4211            .map(|cm| proto::ChannelMembership {
4212                channel_id: cm.channel_id.to_proto(),
4213                role: cm.role.into(),
4214            })
4215            .collect(),
4216        ..Default::default()
4217    };
4218
4219    let mut update = build_channels_update(result.new_channels);
4220    update.delete_channels = result
4221        .removed_channels
4222        .into_iter()
4223        .map(|id| id.to_proto())
4224        .collect();
4225    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4226
4227    for connection_id in connection_pool.user_connection_ids(user_id) {
4228        peer.send(connection_id, user_channels_update.clone())
4229            .trace_err();
4230        peer.send(connection_id, update.clone()).trace_err();
4231    }
4232}
4233
4234fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4235    proto::UpdateUserChannels {
4236        channel_memberships: channels
4237            .channel_memberships
4238            .iter()
4239            .map(|m| proto::ChannelMembership {
4240                channel_id: m.channel_id.to_proto(),
4241                role: m.role.into(),
4242            })
4243            .collect(),
4244        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4245        observed_channel_message_id: channels.observed_channel_messages.clone(),
4246    }
4247}
4248
4249fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4250    let mut update = proto::UpdateChannels::default();
4251
4252    for channel in channels.channels {
4253        update.channels.push(channel.to_proto());
4254    }
4255
4256    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4257    update.latest_channel_message_ids = channels.latest_channel_messages;
4258
4259    for (channel_id, participants) in channels.channel_participants {
4260        update
4261            .channel_participants
4262            .push(proto::ChannelParticipants {
4263                channel_id: channel_id.to_proto(),
4264                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4265            });
4266    }
4267
4268    for channel in channels.invited_channels {
4269        update.channel_invitations.push(channel.to_proto());
4270    }
4271
4272    update
4273}
4274
4275fn build_initial_contacts_update(
4276    contacts: Vec<db::Contact>,
4277    pool: &ConnectionPool,
4278) -> proto::UpdateContacts {
4279    let mut update = proto::UpdateContacts::default();
4280
4281    for contact in contacts {
4282        match contact {
4283            db::Contact::Accepted { user_id, busy } => {
4284                update.contacts.push(contact_for_user(user_id, busy, pool));
4285            }
4286            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4287            db::Contact::Incoming { user_id } => {
4288                update
4289                    .incoming_requests
4290                    .push(proto::IncomingContactRequest {
4291                        requester_id: user_id.to_proto(),
4292                    })
4293            }
4294        }
4295    }
4296
4297    update
4298}
4299
4300fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4301    proto::Contact {
4302        user_id: user_id.to_proto(),
4303        online: pool.is_user_online(user_id),
4304        busy,
4305    }
4306}
4307
4308fn room_updated(room: &proto::Room, peer: &Peer) {
4309    broadcast(
4310        None,
4311        room.participants
4312            .iter()
4313            .filter_map(|participant| Some(participant.peer_id?.into())),
4314        |peer_id| {
4315            peer.send(
4316                peer_id,
4317                proto::RoomUpdated {
4318                    room: Some(room.clone()),
4319                },
4320            )
4321        },
4322    );
4323}
4324
4325fn channel_updated(
4326    channel: &db::channel::Model,
4327    room: &proto::Room,
4328    peer: &Peer,
4329    pool: &ConnectionPool,
4330) {
4331    let participants = room
4332        .participants
4333        .iter()
4334        .map(|p| p.user_id)
4335        .collect::<Vec<_>>();
4336
4337    broadcast(
4338        None,
4339        pool.channel_connection_ids(channel.root_id())
4340            .filter_map(|(channel_id, role)| {
4341                role.can_see_channel(channel.visibility)
4342                    .then_some(channel_id)
4343            }),
4344        |peer_id| {
4345            peer.send(
4346                peer_id,
4347                proto::UpdateChannels {
4348                    channel_participants: vec![proto::ChannelParticipants {
4349                        channel_id: channel.id.to_proto(),
4350                        participant_user_ids: participants.clone(),
4351                    }],
4352                    ..Default::default()
4353                },
4354            )
4355        },
4356    );
4357}
4358
4359async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4360    let db = session.db().await;
4361
4362    let contacts = db.get_contacts(user_id).await?;
4363    let busy = db.is_user_busy(user_id).await?;
4364
4365    let pool = session.connection_pool().await;
4366    let updated_contact = contact_for_user(user_id, busy, &pool);
4367    for contact in contacts {
4368        if let db::Contact::Accepted {
4369            user_id: contact_user_id,
4370            ..
4371        } = contact
4372        {
4373            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4374                session
4375                    .peer
4376                    .send(
4377                        contact_conn_id,
4378                        proto::UpdateContacts {
4379                            contacts: vec![updated_contact.clone()],
4380                            remove_contacts: Default::default(),
4381                            incoming_requests: Default::default(),
4382                            remove_incoming_requests: Default::default(),
4383                            outgoing_requests: Default::default(),
4384                            remove_outgoing_requests: Default::default(),
4385                        },
4386                    )
4387                    .trace_err();
4388            }
4389        }
4390    }
4391    Ok(())
4392}
4393
/// Removes `connection_id` from the room it is in and propagates the
/// consequences to everyone affected.
///
/// Returns `Ok(())` without doing anything else when `leave_room` reports that
/// the connection was not in a room. Otherwise it:
/// - notifies the other connections of every project the connection left,
/// - broadcasts the updated room to its remaining participants,
/// - if the room belongs to a channel, broadcasts the new participant list to
///   connections that can see that channel,
/// - sends `CallCanceled` to every user whose pending call was canceled,
/// - refreshes the contact status of this user and of those callees,
/// - removes the user from the LiveKit room (when a LiveKit client is
///   configured), deleting the LiveKit room when `left_room.deleted` is set.
///
/// # Errors
/// Returns an error if leaving the room in the database fails or if any
/// `update_user_contacts` call fails. LiveKit failures are only logged.
/// NOTE(review): an error from `update_user_contacts` returns early, so the
/// LiveKit cleanup at the end is skipped in that case — confirm that is intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // A set, so each affected user's contacts are refreshed at most once.
    let mut contacts_to_update = HashSet::default();

    // Filled in from `left_room` below; declared here so the values remain
    // available after the `if let` ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this take leaves the default (empty) value in
        // `left_room.room.livekit_room`, so the `room` moved out below and
        // broadcast by `room_updated` carries an empty `livekit_room` — confirm
        // clients do not rely on that field in `RoomUpdated`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` below, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // This `connection_id` shadows the parameter: it iterates the
            // canceled callee's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4467
4468async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4469    let left_channel_buffers = session
4470        .db()
4471        .await
4472        .leave_channel_buffers(session.connection_id)
4473        .await?;
4474
4475    for left_buffer in left_channel_buffers {
4476        channel_buffer_updated(
4477            session.connection_id,
4478            left_buffer.connections,
4479            &proto::UpdateChannelBufferCollaborators {
4480                channel_id: left_buffer.channel_id.to_proto(),
4481                collaborators: left_buffer.collaborators,
4482            },
4483            &session.peer,
4484        );
4485    }
4486
4487    Ok(())
4488}
4489
4490fn project_left(project: &db::LeftProject, session: &Session) {
4491    for connection_id in &project.connection_ids {
4492        if project.should_unshare {
4493            session
4494                .peer
4495                .send(
4496                    *connection_id,
4497                    proto::UnshareProject {
4498                        project_id: project.id.to_proto(),
4499                    },
4500                )
4501                .trace_err();
4502        } else {
4503            session
4504                .peer
4505                .send(
4506                    *connection_id,
4507                    proto::RemoveProjectCollaborator {
4508                        project_id: project.id.to_proto(),
4509                        peer_id: Some(session.connection_id.into()),
4510                    },
4511                )
4512                .trace_err();
4513        }
4514    }
4515}
4516
4517pub trait ResultExt {
4518    type Ok;
4519
4520    fn trace_err(self) -> Option<Self::Ok>;
4521}
4522
4523impl<T, E> ResultExt for Result<T, E>
4524where
4525    E: std::fmt::Debug,
4526{
4527    type Ok = T;
4528
4529    #[track_caller]
4530    fn trace_err(self) -> Option<T> {
4531        match self {
4532            Ok(value) => Some(value),
4533            Err(error) => {
4534                tracing::error!("{:?}", error);
4535                None
4536            }
4537        }
4538    }
4539}