rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Config, Error, RateLimit, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use http_client::HttpClient;
  37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
  38use reqwest_client::ReqwestClient;
  39use rpc::proto::split_repository_update;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{MutexGuard, Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
/// Reconnection grace period for clients.
/// NOTE(review): the consumers of this constant are outside this chunk — confirm the exact
/// semantics (presumably how long a disconnected client may take to reconnect before its
/// state is released).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a newly started server (see `Server::start`) reclaims resources that the
/// database still attributes to stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting longer than that
/// means the old pods are gone before we clean up their resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are used outside this chunk; presumably they are the
// page size for channel-message history, the maximum channel-message length, and the page
// size for notification listings respectively — confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`: takes a boxed envelope plus the
/// sender's `Session` and returns a boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
/// The identity a connection's session acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`. Sessions with this principal are always treated
    /// as staff (see `Session::is_staff`).
    Impersonated { user: User, admin: User },
}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
/// Per-connection state handed to every message handler.
///
/// `Session` is `Clone`; clones share the same database mutex, peer, connection pool and app
/// state through their `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or a user impersonated by an admin).
    principal: Principal,
    /// The connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Server-wide connection pool; acquire it through `Session::connection_pool`, whose
    /// guard is `!Send`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client. NOTE(review): presumably `None` when Supermaven is not
    /// configured — confirm where sessions are constructed.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // `allow(unused)` silences the dead-code lint, i.e. nothing currently reads this field.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-reported system id, if any. NOTE(review): presumably taken from the
    /// `SystemIdHeader` request header — confirm.
    system_id: Option<String>,
    _executor: Executor,
}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message they handle. Populated in
    /// `new`; `add_handler` panics if a message type is registered twice.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`, the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 245
/// Lock guard for the shared [`ConnectionPool`].
///
/// `_not_send` is a `PhantomData<Rc<()>>`, which makes the guard `!Send` and `!Sync`. As a
/// result the guard cannot be held across an `.await` in a future that must be `Send` (such
/// as handler futures registered via `Server::add_handler`), so the synchronous
/// `parking_lot` lock is never held while a task is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 250
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the connection-pool lock stays held for as
/// long as the snapshot is alive; drop it promptly.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points to, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
 268    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 269        let mut server = Self {
 270            id: parking_lot::Mutex::new(id),
 271            peer: Peer::new(id.0 as u32),
 272            app_state: app_state.clone(),
 273            connection_pool: Default::default(),
 274            handlers: Default::default(),
 275            teardown: watch::channel(false).0,
 276        };
 277
 278        server
 279            .add_request_handler(ping)
 280            .add_request_handler(create_room)
 281            .add_request_handler(join_room)
 282            .add_request_handler(rejoin_room)
 283            .add_request_handler(leave_room)
 284            .add_request_handler(set_room_participant_role)
 285            .add_request_handler(call)
 286            .add_request_handler(cancel_call)
 287            .add_message_handler(decline_call)
 288            .add_request_handler(update_participant_location)
 289            .add_request_handler(share_project)
 290            .add_message_handler(unshare_project)
 291            .add_request_handler(join_project)
 292            .add_message_handler(leave_project)
 293            .add_request_handler(update_project)
 294            .add_request_handler(update_worktree)
 295            .add_request_handler(update_repository)
 296            .add_request_handler(remove_repository)
 297            .add_message_handler(start_language_server)
 298            .add_message_handler(update_language_server)
 299            .add_message_handler(update_diagnostic_summary)
 300            .add_message_handler(update_worktree_settings)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 305            .add_request_handler(forward_find_search_candidates_request)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 309            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 311            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 312            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 313            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 314            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 320            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 321            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 322            .add_request_handler(
 323                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 324            )
 325            .add_request_handler(
 326                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 327            )
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 341            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 346            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 351            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 352            .add_request_handler(
 353                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 354            )
 355            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 356            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 357            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 358            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_request_handler(update_buffer)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 368            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 369            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 370            .add_request_handler(get_users)
 371            .add_request_handler(fuzzy_search_users)
 372            .add_request_handler(request_contact)
 373            .add_request_handler(remove_contact)
 374            .add_request_handler(respond_to_contact_request)
 375            .add_message_handler(subscribe_to_channels)
 376            .add_request_handler(create_channel)
 377            .add_request_handler(delete_channel)
 378            .add_request_handler(invite_channel_member)
 379            .add_request_handler(remove_channel_member)
 380            .add_request_handler(set_channel_member_role)
 381            .add_request_handler(set_channel_visibility)
 382            .add_request_handler(rename_channel)
 383            .add_request_handler(join_channel_buffer)
 384            .add_request_handler(leave_channel_buffer)
 385            .add_message_handler(update_channel_buffer)
 386            .add_request_handler(rejoin_channel_buffers)
 387            .add_request_handler(get_channel_members)
 388            .add_request_handler(respond_to_channel_invite)
 389            .add_request_handler(join_channel)
 390            .add_request_handler(join_channel_chat)
 391            .add_message_handler(leave_channel_chat)
 392            .add_request_handler(send_channel_message)
 393            .add_request_handler(remove_channel_message)
 394            .add_request_handler(update_channel_message)
 395            .add_request_handler(get_channel_messages)
 396            .add_request_handler(get_channel_messages_by_id)
 397            .add_request_handler(get_notifications)
 398            .add_request_handler(mark_notification_as_read)
 399            .add_request_handler(move_channel)
 400            .add_request_handler(follow)
 401            .add_message_handler(unfollow)
 402            .add_message_handler(update_followers)
 403            .add_request_handler(get_private_user_info)
 404            .add_request_handler(get_llm_api_token)
 405            .add_request_handler(accept_terms_of_service)
 406            .add_message_handler(acknowledge_channel_message)
 407            .add_message_handler(acknowledge_buffer_version)
 408            .add_request_handler(get_supermaven_api_key)
 409            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 410            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 411            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 413            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 415            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 416            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 417            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 418            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 419            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 420            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 421            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 422            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 424            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 425            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 426            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 427            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 428            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 429            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 430            .add_message_handler(update_context)
 431            .add_request_handler({
 432                let app_state = app_state.clone();
 433                move |request, response, session| {
 434                    let app_state = app_state.clone();
 435                    async move {
 436                        count_language_model_tokens(request, response, session, &app_state.config)
 437                            .await
 438                    }
 439                }
 440            })
 441            .add_request_handler(get_cached_embeddings)
 442            .add_request_handler({
 443                let app_state = app_state.clone();
 444                move |request, response, session| {
 445                    compute_embeddings(
 446                        request,
 447                        response,
 448                        session,
 449                        app_state.config.openai_api_key.clone(),
 450                    )
 451                }
 452            });
 453
 454        Arc::new(server)
 455    }
 456
    /// Spawns a detached background task that cleans up resources the database still
    /// attributes to stale servers in this server's `zed_environment`, then returns `Ok(())`
    /// without waiting for it.
    ///
    /// The task:
    /// 1. sleeps for [`CLEANUP_TIMEOUT`];
    /// 2. fetches stale room ids and channel ids via `stale_server_resource_ids`;
    /// 3. for each channel, clears stale channel-buffer collaborators and sends the refreshed
    ///    collaborator list (`UpdateChannelBufferCollaborators`) to each connection in the
    ///    refreshed buffer's `connection_ids`;
    /// 4. for each room:
    ///    - clears stale participants;
    ///    - broadcasts the refreshed room, and its channel if it has one;
    ///    - sends `CallCanceled` to every connection of users whose calls were canceled;
    ///    - pushes an updated contact entry for each affected user to that user's accepted
    ///      contacts;
    ///    - deletes the LiveKit room when no participants remain;
    /// 5. deletes stale server records via `delete_stale_servers`.
    ///
    /// Fallible steps are logged with `trace_err` and skipped rather than aborting the task;
    /// step 5 runs even if step 2 fails.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop stale collaborators and re-broadcast the list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Stale rooms. If clearing participants fails, the defaults below mean no
                    // calls are canceled, no contacts are updated and no LiveKit room is deleted.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only an empty room gets its LiveKit room deleted below.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the awaits in
                        // the contacts loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Both lookups must succeed; the pool is locked only after both
                            // awaits and released at the end of this block.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 602
 603    pub fn teardown(&self) {
 604        self.peer.teardown();
 605        self.connection_pool.lock().reset();
 606        let _ = self.teardown.send(true);
 607    }
 608
 609    #[cfg(test)]
 610    pub fn reset(&self, id: ServerId) {
 611        self.teardown();
 612        *self.id.lock() = id;
 613        self.peer.reset(id.0 as u32);
 614        let _ = self.teardown.send(false);
 615    }
 616
 617    #[cfg(test)]
 618    pub fn id(&self) -> ServerId {
 619        *self.id.lock()
 620    }
 621
 622    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 623    where
 624        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 625        Fut: 'static + Send + Future<Output = Result<()>>,
 626        M: EnvelopedMessage,
 627    {
 628        let prev_handler = self.handlers.insert(
 629            TypeId::of::<M>(),
 630            Box::new(move |envelope, session| {
 631                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 632                let received_at = envelope.received_at;
 633                tracing::info!("message received");
 634                let start_time = Instant::now();
 635                let future = (handler)(*envelope, session);
 636                async move {
 637                    let result = future.await;
 638                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 639                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 640                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 641                    let payload_type = M::NAME;
 642
 643                    match result {
 644                        Err(error) => {
 645                            tracing::error!(
 646                                ?error,
 647                                total_duration_ms,
 648                                processing_duration_ms,
 649                                queue_duration_ms,
 650                                payload_type,
 651                                "error handling message"
 652                            )
 653                        }
 654                        Ok(()) => tracing::info!(
 655                            total_duration_ms,
 656                            processing_duration_ms,
 657                            queue_duration_ms,
 658                            "finished handling message"
 659                        ),
 660                    }
 661                }
 662                .boxed()
 663            }),
 664        );
 665        if prev_handler.is_some() {
 666            panic!("registered a handler for the same message twice");
 667        }
 668        self
 669    }
 670
 671    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 672    where
 673        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 674        Fut: 'static + Send + Future<Output = Result<()>>,
 675        M: EnvelopedMessage,
 676    {
 677        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 678        self
 679    }
 680
 681    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 682    where
 683        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 684        Fut: Send + Future<Output = Result<()>>,
 685        M: RequestMessage,
 686    {
 687        let handler = Arc::new(handler);
 688        self.add_handler(move |envelope, session| {
 689            let receipt = envelope.receipt();
 690            let handler = handler.clone();
 691            async move {
 692                let peer = session.peer.clone();
 693                let responded = Arc::new(AtomicBool::default());
 694                let response = Response {
 695                    peer: peer.clone(),
 696                    responded: responded.clone(),
 697                    receipt,
 698                };
 699                match (handler)(envelope.payload, response, session).await {
 700                    Ok(()) => {
 701                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 702                            Ok(())
 703                        } else {
 704                            Err(anyhow!("handler did not send a response"))?
 705                        }
 706                    }
 707                    Err(error) => {
 708                        let proto_err = match &error {
 709                            Error::Internal(err) => err.to_proto(),
 710                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 711                        };
 712                        peer.respond_with_error(receipt, proto_err)?;
 713                        Err(error)
 714                    }
 715                }
 716            }
 717        })
 718    }
 719
    /// Drives a single client connection from registration until it ends.
    ///
    /// The returned future does not borrow `self`: it captures a cloned `Arc<Server>` (hence
    /// `+ use<>`). When polled, it does the following:
    /// 1. Returns immediately if the server is already tearing down.
    /// 2. Registers `connection` with the peer and records the connection id on the span.
    /// 3. Builds the per-connection `Session`, including an HTTP client and, when an admin API
    ///    key is configured, a Supermaven admin client.
    /// 4. Calls `send_initial_client_update`, which also receives `send_connection_id`.
    /// 5. Dispatches incoming messages to the registered handlers until a teardown signal
    ///    arrives, the connection's I/O future completes, or the incoming stream ends.
    /// 6. In the last two cases, runs `connection_lost` for the session.
    ///
    /// `geoip_country_code` and `system_id` are stored on the session. The country code is
    /// also recorded on the tracing span.
    ///
    /// NOTE(review): these paths return without calling `connection_lost`:
    /// - the early `return`s after `add_connection` (HTTP client creation fails, or the
    ///   initial client update fails — possibly after the connection was already added to
    ///   the pool);
    /// - the teardown branch of the message loop.
    /// Nothing in this function disconnects the peer or removes the pool entry on those
    /// paths. Confirm that is intended or handled elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the peer assigns
        // one, and (presumably) the user fields by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe eagerly so the returned future does not need to borrow `self`.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `add_connection` yields the connection id, a future that performs the
            // connection's I/O (polled in the loop below), and the stream of incoming
            // messages. The closure lets the peer sleep on this executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. Each message holds a permit
            // until its handler finishes, so once all permits are taken no further messages
            // are read (backpressure).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The permit is acquired *before* reading the next message. The semaphore is
                // never closed in this function, so the `unwrap` is not expected to fail.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order:
                // 1. teardown — returns without running `connection_lost`;
                // 2. connection I/O;
                // 3. progress on foreground handlers;
                // 4. the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span here; the handler future re-enters it through
                                // `.instrument(span)` each time it is polled.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the concurrency permit only once the handler is done.
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; all others
                                // are polled by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop (cancel) any foreground handlers still in flight before cleaning up.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 868
 869    async fn send_initial_client_update(
 870        &self,
 871        connection_id: ConnectionId,
 872        principal: &Principal,
 873        zed_version: ZedVersion,
 874        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 875        session: &Session,
 876    ) -> Result<()> {
 877        self.peer.send(
 878            connection_id,
 879            proto::Hello {
 880                peer_id: Some(connection_id.into()),
 881            },
 882        )?;
 883        tracing::info!("sent hello message");
 884        if let Some(send_connection_id) = send_connection_id.take() {
 885            let _ = send_connection_id.send(connection_id);
 886        }
 887
 888        match principal {
 889            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 890                if !user.connected_once {
 891                    self.peer.send(connection_id, proto::ShowContacts {})?;
 892                    self.app_state
 893                        .db
 894                        .set_user_connected_once(user.id, true)
 895                        .await?;
 896                }
 897
 898                update_user_plan(user.id, session).await?;
 899
 900                let contacts = self.app_state.db.get_contacts(user.id).await?;
 901
 902                {
 903                    let mut pool = self.connection_pool.lock();
 904                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 905                    self.peer.send(
 906                        connection_id,
 907                        build_initial_contacts_update(contacts, &pool),
 908                    )?;
 909                }
 910
 911                if should_auto_subscribe_to_channels(zed_version) {
 912                    subscribe_user_to_channels(user.id, session).await?;
 913                }
 914
 915                if let Some(incoming_call) =
 916                    self.app_state.db.incoming_call_for_user(user.id).await?
 917                {
 918                    self.peer.send(connection_id, incoming_call)?;
 919                }
 920
 921                update_user_contacts(user.id, session).await?;
 922            }
 923        }
 924
 925        Ok(())
 926    }
 927
 928    pub async fn invite_code_redeemed(
 929        self: &Arc<Self>,
 930        inviter_id: UserId,
 931        invitee_id: UserId,
 932    ) -> Result<()> {
 933        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 934            if let Some(code) = &user.invite_code {
 935                let pool = self.connection_pool.lock();
 936                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 937                for connection_id in pool.user_connection_ids(inviter_id) {
 938                    self.peer.send(
 939                        connection_id,
 940                        proto::UpdateContacts {
 941                            contacts: vec![invitee_contact.clone()],
 942                            ..Default::default()
 943                        },
 944                    )?;
 945                    self.peer.send(
 946                        connection_id,
 947                        proto::UpdateInviteInfo {
 948                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 949                            count: user.invite_count as u32,
 950                        },
 951                    )?;
 952                }
 953            }
 954        }
 955        Ok(())
 956    }
 957
 958    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 959        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 960            if let Some(invite_code) = &user.invite_code {
 961                let pool = self.connection_pool.lock();
 962                for connection_id in pool.user_connection_ids(user_id) {
 963                    self.peer.send(
 964                        connection_id,
 965                        proto::UpdateInviteInfo {
 966                            url: format!(
 967                                "{}{}",
 968                                self.app_state.config.invite_link_prefix, invite_code
 969                            ),
 970                            count: user.invite_count as u32,
 971                        },
 972                    )?;
 973                }
 974            }
 975        }
 976        Ok(())
 977    }
 978
 979    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 980        let pool = self.connection_pool.lock();
 981        for connection_id in pool.user_connection_ids(user_id) {
 982            self.peer
 983                .send(connection_id, proto::RefreshLlmToken {})
 984                .trace_err();
 985        }
 986    }
 987
 988    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 989        ServerSnapshot {
 990            connection_pool: ConnectionPoolGuard {
 991                guard: self.connection_pool.lock(),
 992                _not_send: PhantomData,
 993            },
 994            peer: &self.peer,
 995        }
 996    }
 997}
 998
 999impl Deref for ConnectionPoolGuard<'_> {
1000    type Target = ConnectionPool;
1001
1002    fn deref(&self) -> &Self::Target {
1003        &self.guard
1004    }
1005}
1006
1007impl DerefMut for ConnectionPoolGuard<'_> {
1008    fn deref_mut(&mut self) -> &mut Self::Target {
1009        &mut self.guard
1010    }
1011}
1012
/// In test builds, dropping the guard validates the pool's invariants.
///
/// `drop` runs before the guard's fields (including the lock guard) are dropped, so the
/// check happens while the pool is still locked.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`; in other builds this `drop` does nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
1019
1020fn broadcast<F>(
1021    sender_id: Option<ConnectionId>,
1022    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1023    mut f: F,
1024) where
1025    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1026{
1027    for receiver_id in receiver_ids {
1028        if Some(receiver_id) != sender_id {
1029            if let Err(error) = f(receiver_id) {
1030                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1031            }
1032        }
1033    }
1034}
1035
1036pub struct ProtocolVersion(u32);
1037
1038impl Header for ProtocolVersion {
1039    fn name() -> &'static HeaderName {
1040        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1042    }
1043
1044    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045    where
1046        Self: Sized,
1047        I: Iterator<Item = &'i axum::http::HeaderValue>,
1048    {
1049        let version = values
1050            .next()
1051            .ok_or_else(axum::headers::Error::invalid)?
1052            .to_str()
1053            .map_err(|_| axum::headers::Error::invalid())?
1054            .parse()
1055            .map_err(|_| axum::headers::Error::invalid())?;
1056        Ok(Self(version))
1057    }
1058
1059    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060        values.extend([self.0.to_string().parse().unwrap()]);
1061    }
1062}
1063
1064pub struct AppVersionHeader(SemanticVersion);
1065impl Header for AppVersionHeader {
1066    fn name() -> &'static HeaderName {
1067        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1068        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1069    }
1070
1071    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1072    where
1073        Self: Sized,
1074        I: Iterator<Item = &'i axum::http::HeaderValue>,
1075    {
1076        let version = values
1077            .next()
1078            .ok_or_else(axum::headers::Error::invalid)?
1079            .to_str()
1080            .map_err(|_| axum::headers::Error::invalid())?
1081            .parse()
1082            .map_err(|_| axum::headers::Error::invalid())?;
1083        Ok(Self(version))
1084    }
1085
1086    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1087        values.extend([self.0.to_string().parse().unwrap()]);
1088    }
1089}
1090
/// Builds the collab HTTP router with two endpoints: the `/rpc` WebSocket upgrade and
/// `/metrics`.
///
/// Layer order is load-bearing. In axum, `Router::layer` only wraps routes added *before*
/// the call, so:
/// - the `ServiceBuilder` layer (the `AppState` extension plus `auth::validate_header`)
///   applies to `/rpc` only;
/// - `/metrics`, registered afterwards, does not pass through `auth::validate_header`;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // `ServiceBuilder` layers run outermost-first: the `AppState` extension is
            // installed before `auth::validate_header` runs (presumably so it can reach the
            // app state — confirm in `auth`).
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1102
1103pub async fn handle_websocket_request(
1104    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1105    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1106    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1107    Extension(server): Extension<Arc<Server>>,
1108    Extension(principal): Extension<Principal>,
1109    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1110    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1111    ws: WebSocketUpgrade,
1112) -> axum::response::Response {
1113    if protocol_version != rpc::PROTOCOL_VERSION {
1114        return (
1115            StatusCode::UPGRADE_REQUIRED,
1116            "client must be upgraded".to_string(),
1117        )
1118            .into_response();
1119    }
1120
1121    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1122        return (
1123            StatusCode::UPGRADE_REQUIRED,
1124            "no version header found".to_string(),
1125        )
1126            .into_response();
1127    };
1128
1129    if !version.can_collaborate() {
1130        return (
1131            StatusCode::UPGRADE_REQUIRED,
1132            "client must be upgraded".to_string(),
1133        )
1134            .into_response();
1135    }
1136
1137    let socket_address = socket_address.to_string();
1138    ws.on_upgrade(move |socket| {
1139        let socket = socket
1140            .map_ok(to_tungstenite_message)
1141            .err_into()
1142            .with(|message| async move { to_axum_message(message) });
1143        let connection = Connection::new(Box::pin(socket));
1144        async move {
1145            server
1146                .handle_connection(
1147                    connection,
1148                    socket_address,
1149                    principal,
1150                    version,
1151                    country_code_header.map(|header| header.to_string()),
1152                    system_id_header.map(|header| header.to_string()),
1153                    None,
1154                    Executor::Production,
1155                )
1156                .await;
1157        }
1158    })
1159}
1160
1161pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1162    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163    let connections_metric = CONNECTIONS_METRIC
1164        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1165
1166    let connections = server
1167        .connection_pool
1168        .lock()
1169        .connections()
1170        .filter(|connection| !connection.admin)
1171        .count();
1172    connections_metric.set(connections as _);
1173
1174    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1176        register_int_gauge!(
1177            "shared_projects",
1178            "number of open projects with one or more guests"
1179        )
1180        .unwrap()
1181    });
1182
1183    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1184    shared_projects_metric.set(shared_projects as _);
1185
1186    let encoder = prometheus::TextEncoder::new();
1187    let metric_families = prometheus::gather();
1188    let encoded_metrics = encoder
1189        .encode_to_string(&metric_families)
1190        .map_err(|err| anyhow!("{}", err))?;
1191    Ok(encoded_metrics)
1192}
1193
1194#[instrument(err, skip(executor))]
1195async fn connection_lost(
1196    session: Session,
1197    mut teardown: watch::Receiver<bool>,
1198    executor: Executor,
1199) -> Result<()> {
1200    session.peer.disconnect(session.connection_id);
1201    session
1202        .connection_pool()
1203        .await
1204        .remove_connection(session.connection_id)?;
1205
1206    session
1207        .db()
1208        .await
1209        .connection_lost(session.connection_id)
1210        .await
1211        .trace_err();
1212
1213    futures::select_biased! {
1214        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1215
1216            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1217            leave_room_for_session(&session, session.connection_id).await.trace_err();
1218            leave_channel_buffers_for_session(&session)
1219                .await
1220                .trace_err();
1221
1222            if !session
1223                .connection_pool()
1224                .await
1225                .is_user_online(session.user_id())
1226            {
1227                let db = session.db().await;
1228                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1229                    room_updated(&room, &session.peer);
1230                }
1231            }
1232
1233            update_user_contacts(session.user_id(), &session).await?;
1234        },
1235        _ = teardown.changed().fuse() => {}
1236    }
1237
1238    Ok(())
1239}
1240
1241/// Acknowledges a ping from a client, used to keep the connection alive.
1242async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1243    response.send(proto::Ack {})?;
1244    Ok(())
1245}
1246
1247/// Creates a new room for calling (outside of channels)
1248async fn create_room(
1249    _request: proto::CreateRoom,
1250    response: Response<proto::CreateRoom>,
1251    session: Session,
1252) -> Result<()> {
1253    let livekit_room = nanoid::nanoid!(30);
1254
1255    let live_kit_connection_info = util::maybe!(async {
1256        let live_kit = session.app_state.livekit_client.as_ref();
1257        let live_kit = live_kit?;
1258        let user_id = session.user_id().to_string();
1259
1260        let token = live_kit
1261            .room_token(&livekit_room, &user_id.to_string())
1262            .trace_err()?;
1263
1264        Some(proto::LiveKitConnectionInfo {
1265            server_url: live_kit.url().into(),
1266            token,
1267            can_publish: true,
1268        })
1269    })
1270    .await;
1271
1272    let room = session
1273        .db()
1274        .await
1275        .create_room(session.user_id(), session.connection_id, &livekit_room)
1276        .await?;
1277
1278    response.send(proto::CreateRoomResponse {
1279        room: Some(room.clone()),
1280        live_kit_connection_info,
1281    })?;
1282
1283    update_user_contacts(session.user_id(), &session).await?;
1284    Ok(())
1285}
1286
1287/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1288async fn join_room(
1289    request: proto::JoinRoom,
1290    response: Response<proto::JoinRoom>,
1291    session: Session,
1292) -> Result<()> {
1293    let room_id = RoomId::from_proto(request.id);
1294
1295    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1296
1297    if let Some(channel_id) = channel_id {
1298        return join_channel_internal(channel_id, Box::new(response), session).await;
1299    }
1300
1301    let joined_room = {
1302        let room = session
1303            .db()
1304            .await
1305            .join_room(room_id, session.user_id(), session.connection_id)
1306            .await?;
1307        room_updated(&room.room, &session.peer);
1308        room.into_inner()
1309    };
1310
1311    for connection_id in session
1312        .connection_pool()
1313        .await
1314        .user_connection_ids(session.user_id())
1315    {
1316        session
1317            .peer
1318            .send(
1319                connection_id,
1320                proto::CallCanceled {
1321                    room_id: room_id.to_proto(),
1322                },
1323            )
1324            .trace_err();
1325    }
1326
1327    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1328    {
1329        live_kit
1330            .room_token(
1331                &joined_room.room.livekit_room,
1332                &session.user_id().to_string(),
1333            )
1334            .trace_err()
1335            .map(|token| proto::LiveKitConnectionInfo {
1336                server_url: live_kit.url().into(),
1337                token,
1338                can_publish: true,
1339            })
1340    } else {
1341        None
1342    };
1343
1344    response.send(proto::JoinRoomResponse {
1345        room: Some(joined_room.room),
1346        channel_id: None,
1347        live_kit_connection_info,
1348    })?;
1349
1350    update_user_contacts(session.user_id(), &session).await?;
1351    Ok(())
1352}
1353
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Restore the caller's membership via `rejoin_room` on the session's database handle.
/// 2. Reply with the room plus the reshared and rejoined projects, and broadcast the room
///    update.
/// 3. For each reshared project:
///    - send `UpdateProjectCollaborator` (old peer id → this connection) to each listed
///      collaborator;
///    - forward the project's worktrees (`UpdateProject`) to every collaborator except the
///      caller.
/// 4. Hand rejoined projects to `notify_rejoined_projects`.
/// 5. Update the channel if the room belongs to one.
/// 6. Refresh the user's contacts.
///
/// # Errors
/// Propagates database errors, failures sending the reply, errors from
/// `notify_rejoined_projects`, and contact-update failures. Per-collaborator send failures
/// are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so `rejoined_room` (which `into_inner()` suggests is a guard around the database
    // result — NOTE(review): confirm it holds a lock/transaction) is dropped before the
    // connection pool is locked below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id changed for this project.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1445
1446fn notify_rejoined_projects(
1447    rejoined_projects: &mut Vec<RejoinedProject>,
1448    session: &Session,
1449) -> Result<()> {
1450    for project in rejoined_projects.iter() {
1451        for collaborator in &project.collaborators {
1452            session
1453                .peer
1454                .send(
1455                    collaborator.connection_id,
1456                    proto::UpdateProjectCollaborator {
1457                        project_id: project.id.to_proto(),
1458                        old_peer_id: Some(project.old_connection_id.into()),
1459                        new_peer_id: Some(session.connection_id.into()),
1460                    },
1461                )
1462                .trace_err();
1463        }
1464    }
1465
1466    for project in rejoined_projects {
1467        for worktree in mem::take(&mut project.worktrees) {
1468            // Stream this worktree's entries.
1469            let message = proto::UpdateWorktree {
1470                project_id: project.id.to_proto(),
1471                worktree_id: worktree.id,
1472                abs_path: worktree.abs_path.clone(),
1473                root_name: worktree.root_name,
1474                updated_entries: worktree.updated_entries,
1475                removed_entries: worktree.removed_entries,
1476                scan_id: worktree.scan_id,
1477                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1478                updated_repositories: worktree.updated_repositories,
1479                removed_repositories: worktree.removed_repositories,
1480            };
1481            for update in proto::split_worktree_update(message) {
1482                session.peer.send(session.connection_id, update)?;
1483            }
1484
1485            // Stream this worktree's diagnostics.
1486            for summary in worktree.diagnostic_summaries {
1487                session.peer.send(
1488                    session.connection_id,
1489                    proto::UpdateDiagnosticSummary {
1490                        project_id: project.id.to_proto(),
1491                        worktree_id: worktree.id,
1492                        summary: Some(summary),
1493                    },
1494                )?;
1495            }
1496
1497            for settings_file in worktree.settings_files {
1498                session.peer.send(
1499                    session.connection_id,
1500                    proto::UpdateWorktreeSettings {
1501                        project_id: project.id.to_proto(),
1502                        worktree_id: worktree.id,
1503                        path: settings_file.path,
1504                        content: Some(settings_file.content),
1505                        kind: Some(settings_file.kind.to_proto().into()),
1506                    },
1507                )?;
1508            }
1509        }
1510
1511        for repository in mem::take(&mut project.updated_repositories) {
1512            for update in split_repository_update(repository) {
1513                session.peer.send(session.connection_id, update)?;
1514            }
1515        }
1516
1517        for id in mem::take(&mut project.removed_repositories) {
1518            session.peer.send(
1519                session.connection_id,
1520                proto::RemoveRepository {
1521                    project_id: project.id.to_proto(),
1522                    id,
1523                },
1524            )?;
1525        }
1526    }
1527
1528    Ok(())
1529}
1530
1531/// leave room disconnects from the room.
1532async fn leave_room(
1533    _: proto::LeaveRoom,
1534    response: Response<proto::LeaveRoom>,
1535    session: Session,
1536) -> Result<()> {
1537    leave_room_for_session(&session, session.connection_id).await?;
1538    response.send(proto::Ack {})?;
1539    Ok(())
1540}
1541
1542/// Updates the permissions of someone else in the room.
1543async fn set_room_participant_role(
1544    request: proto::SetRoomParticipantRole,
1545    response: Response<proto::SetRoomParticipantRole>,
1546    session: Session,
1547) -> Result<()> {
1548    let user_id = UserId::from_proto(request.user_id);
1549    let role = ChannelRole::from(request.role());
1550
1551    let (livekit_room, can_publish) = {
1552        let room = session
1553            .db()
1554            .await
1555            .set_room_participant_role(
1556                session.user_id(),
1557                RoomId::from_proto(request.room_id),
1558                user_id,
1559                role,
1560            )
1561            .await?;
1562
1563        let livekit_room = room.livekit_room.clone();
1564        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1565        room_updated(&room, &session.peer);
1566        (livekit_room, can_publish)
1567    };
1568
1569    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1570        live_kit
1571            .update_participant(
1572                livekit_room.clone(),
1573                request.user_id.to_string(),
1574                livekit_api::proto::ParticipantPermission {
1575                    can_subscribe: true,
1576                    can_publish,
1577                    can_publish_data: can_publish,
1578                    hidden: false,
1579                    recorder: false,
1580                },
1581            )
1582            .await
1583            .trace_err();
1584    }
1585
1586    response.send(proto::Ack {})?;
1587    Ok(())
1588}
1589
1590/// Call someone else into the current room
1591async fn call(
1592    request: proto::Call,
1593    response: Response<proto::Call>,
1594    session: Session,
1595) -> Result<()> {
1596    let room_id = RoomId::from_proto(request.room_id);
1597    let calling_user_id = session.user_id();
1598    let calling_connection_id = session.connection_id;
1599    let called_user_id = UserId::from_proto(request.called_user_id);
1600    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1601    if !session
1602        .db()
1603        .await
1604        .has_contact(calling_user_id, called_user_id)
1605        .await?
1606    {
1607        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1608    }
1609
1610    let incoming_call = {
1611        let (room, incoming_call) = &mut *session
1612            .db()
1613            .await
1614            .call(
1615                room_id,
1616                calling_user_id,
1617                calling_connection_id,
1618                called_user_id,
1619                initial_project_id,
1620            )
1621            .await?;
1622        room_updated(room, &session.peer);
1623        mem::take(incoming_call)
1624    };
1625    update_user_contacts(called_user_id, &session).await?;
1626
1627    let mut calls = session
1628        .connection_pool()
1629        .await
1630        .user_connection_ids(called_user_id)
1631        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1632        .collect::<FuturesUnordered<_>>();
1633
1634    while let Some(call_response) = calls.next().await {
1635        match call_response.as_ref() {
1636            Ok(_) => {
1637                response.send(proto::Ack {})?;
1638                return Ok(());
1639            }
1640            Err(_) => {
1641                call_response.trace_err();
1642            }
1643        }
1644    }
1645
1646    {
1647        let room = session
1648            .db()
1649            .await
1650            .call_failed(room_id, called_user_id)
1651            .await?;
1652        room_updated(&room, &session.peer);
1653    }
1654    update_user_contacts(called_user_id, &session).await?;
1655
1656    Err(anyhow!("failed to ring user"))?
1657}
1658
1659/// Cancel an outgoing call.
1660async fn cancel_call(
1661    request: proto::CancelCall,
1662    response: Response<proto::CancelCall>,
1663    session: Session,
1664) -> Result<()> {
1665    let called_user_id = UserId::from_proto(request.called_user_id);
1666    let room_id = RoomId::from_proto(request.room_id);
1667    {
1668        let room = session
1669            .db()
1670            .await
1671            .cancel_call(room_id, session.connection_id, called_user_id)
1672            .await?;
1673        room_updated(&room, &session.peer);
1674    }
1675
1676    for connection_id in session
1677        .connection_pool()
1678        .await
1679        .user_connection_ids(called_user_id)
1680    {
1681        session
1682            .peer
1683            .send(
1684                connection_id,
1685                proto::CallCanceled {
1686                    room_id: room_id.to_proto(),
1687                },
1688            )
1689            .trace_err();
1690    }
1691    response.send(proto::Ack {})?;
1692
1693    update_user_contacts(called_user_id, &session).await?;
1694    Ok(())
1695}
1696
1697/// Decline an incoming call.
1698async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1699    let room_id = RoomId::from_proto(message.room_id);
1700    {
1701        let room = session
1702            .db()
1703            .await
1704            .decline_call(Some(room_id), session.user_id())
1705            .await?
1706            .ok_or_else(|| anyhow!("failed to decline call"))?;
1707        room_updated(&room, &session.peer);
1708    }
1709
1710    for connection_id in session
1711        .connection_pool()
1712        .await
1713        .user_connection_ids(session.user_id())
1714    {
1715        session
1716            .peer
1717            .send(
1718                connection_id,
1719                proto::CallCanceled {
1720                    room_id: room_id.to_proto(),
1721                },
1722            )
1723            .trace_err();
1724    }
1725    update_user_contacts(session.user_id(), &session).await?;
1726    Ok(())
1727}
1728
1729/// Updates other participants in the room with your current location.
1730async fn update_participant_location(
1731    request: proto::UpdateParticipantLocation,
1732    response: Response<proto::UpdateParticipantLocation>,
1733    session: Session,
1734) -> Result<()> {
1735    let room_id = RoomId::from_proto(request.room_id);
1736    let location = request
1737        .location
1738        .ok_or_else(|| anyhow!("invalid location"))?;
1739
1740    let db = session.db().await;
1741    let room = db
1742        .update_room_participant_location(room_id, session.connection_id, location)
1743        .await?;
1744
1745    room_updated(&room, &session.peer);
1746    response.send(proto::Ack {})?;
1747    Ok(())
1748}
1749
1750/// Share a project into the room.
1751async fn share_project(
1752    request: proto::ShareProject,
1753    response: Response<proto::ShareProject>,
1754    session: Session,
1755) -> Result<()> {
1756    let (project_id, room) = &*session
1757        .db()
1758        .await
1759        .share_project(
1760            RoomId::from_proto(request.room_id),
1761            session.connection_id,
1762            &request.worktrees,
1763            request.is_ssh_project,
1764        )
1765        .await?;
1766    response.send(proto::ShareProjectResponse {
1767        project_id: project_id.to_proto(),
1768    })?;
1769    room_updated(room, &session.peer);
1770
1771    Ok(())
1772}
1773
1774/// Unshare a project from the room.
1775async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1776    let project_id = ProjectId::from_proto(message.project_id);
1777    unshare_project_internal(project_id, session.connection_id, &session).await
1778}
1779
/// Unshares `project_id` on behalf of `connection_id` and notifies everyone
/// affected.
///
/// While the guard returned by the database is held, every guest connection it
/// reports (except `connection_id` itself) is sent `UnshareProject`, and the
/// room, if there is one, is broadcast. If the database indicates that the
/// project should be deleted, it is deleted afterwards, once that guard has
/// been released.
///
/// # Errors
///
/// Returns an error if the database unshare or the subsequent delete fails.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the guard is dropped before the (optional) delete below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1817
1818/// Join someone elses shared project.
1819async fn join_project(
1820    request: proto::JoinProject,
1821    response: Response<proto::JoinProject>,
1822    session: Session,
1823) -> Result<()> {
1824    let project_id = ProjectId::from_proto(request.project_id);
1825
1826    tracing::info!(%project_id, "join project");
1827
1828    let db = session.db().await;
1829    let (project, replica_id) = &mut *db
1830        .join_project(project_id, session.connection_id, session.user_id())
1831        .await?;
1832    drop(db);
1833    tracing::info!(%project_id, "join remote project");
1834    join_project_internal(response, session, project, replica_id)
1835}
1836
/// Reply channel used by `join_project_internal` to deliver the
/// `JoinProjectResponse`. Only the `Response<proto::JoinProject>` impl is
/// visible here; presumably the trait exists so other join paths can reuse
/// `join_project_internal` — TODO confirm.
trait JoinProjectInternalResponse {
    /// Sends `result` back to the joining client, consuming the channel.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response`'s own `send`. It is named by path so it is clear
// this is not a recursive call to the trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1845
1846fn join_project_internal(
1847    response: impl JoinProjectInternalResponse,
1848    session: Session,
1849    project: &mut Project,
1850    replica_id: &ReplicaId,
1851) -> Result<()> {
1852    let collaborators = project
1853        .collaborators
1854        .iter()
1855        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1856        .map(|collaborator| collaborator.to_proto())
1857        .collect::<Vec<_>>();
1858    let project_id = project.id;
1859    let guest_user_id = session.user_id();
1860
1861    let worktrees = project
1862        .worktrees
1863        .iter()
1864        .map(|(id, worktree)| proto::WorktreeMetadata {
1865            id: *id,
1866            root_name: worktree.root_name.clone(),
1867            visible: worktree.visible,
1868            abs_path: worktree.abs_path.clone(),
1869        })
1870        .collect::<Vec<_>>();
1871
1872    let add_project_collaborator = proto::AddProjectCollaborator {
1873        project_id: project_id.to_proto(),
1874        collaborator: Some(proto::Collaborator {
1875            peer_id: Some(session.connection_id.into()),
1876            replica_id: replica_id.0 as u32,
1877            user_id: guest_user_id.to_proto(),
1878            is_host: false,
1879        }),
1880    };
1881
1882    for collaborator in &collaborators {
1883        session
1884            .peer
1885            .send(
1886                collaborator.peer_id.unwrap().into(),
1887                add_project_collaborator.clone(),
1888            )
1889            .trace_err();
1890    }
1891
1892    // First, we send the metadata associated with each worktree.
1893    response.send(proto::JoinProjectResponse {
1894        project_id: project.id.0 as u64,
1895        worktrees: worktrees.clone(),
1896        replica_id: replica_id.0 as u32,
1897        collaborators: collaborators.clone(),
1898        language_servers: project.language_servers.clone(),
1899        role: project.role.into(),
1900    })?;
1901
1902    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1903        // Stream this worktree's entries.
1904        let message = proto::UpdateWorktree {
1905            project_id: project_id.to_proto(),
1906            worktree_id,
1907            abs_path: worktree.abs_path.clone(),
1908            root_name: worktree.root_name,
1909            updated_entries: worktree.entries,
1910            removed_entries: Default::default(),
1911            scan_id: worktree.scan_id,
1912            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1913            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1914            removed_repositories: Default::default(),
1915        };
1916        for update in proto::split_worktree_update(message) {
1917            session.peer.send(session.connection_id, update.clone())?;
1918        }
1919
1920        // Stream this worktree's diagnostics.
1921        for summary in worktree.diagnostic_summaries {
1922            session.peer.send(
1923                session.connection_id,
1924                proto::UpdateDiagnosticSummary {
1925                    project_id: project_id.to_proto(),
1926                    worktree_id: worktree.id,
1927                    summary: Some(summary),
1928                },
1929            )?;
1930        }
1931
1932        for settings_file in worktree.settings_files {
1933            session.peer.send(
1934                session.connection_id,
1935                proto::UpdateWorktreeSettings {
1936                    project_id: project_id.to_proto(),
1937                    worktree_id: worktree.id,
1938                    path: settings_file.path,
1939                    content: Some(settings_file.content),
1940                    kind: Some(settings_file.kind.to_proto() as i32),
1941                },
1942            )?;
1943        }
1944    }
1945
1946    for repository in mem::take(&mut project.repositories) {
1947        for update in split_repository_update(repository) {
1948            session.peer.send(session.connection_id, update)?;
1949        }
1950    }
1951
1952    for language_server in &project.language_servers {
1953        session.peer.send(
1954            session.connection_id,
1955            proto::UpdateLanguageServer {
1956                project_id: project_id.to_proto(),
1957                language_server_id: language_server.id,
1958                variant: Some(
1959                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1960                        proto::LspDiskBasedDiagnosticsUpdated {},
1961                    ),
1962                ),
1963            },
1964        )?;
1965    }
1966
1967    Ok(())
1968}
1969
1970/// Leave someone elses shared project.
1971async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1972    let sender_id = session.connection_id;
1973    let project_id = ProjectId::from_proto(request.project_id);
1974    let db = session.db().await;
1975
1976    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1977    tracing::info!(
1978        %project_id,
1979        "leave project"
1980    );
1981
1982    project_left(project, &session);
1983    if let Some(room) = room {
1984        room_updated(room, &session.peer);
1985    }
1986
1987    Ok(())
1988}
1989
1990/// Updates other participants with changes to the project
1991async fn update_project(
1992    request: proto::UpdateProject,
1993    response: Response<proto::UpdateProject>,
1994    session: Session,
1995) -> Result<()> {
1996    let project_id = ProjectId::from_proto(request.project_id);
1997    let (room, guest_connection_ids) = &*session
1998        .db()
1999        .await
2000        .update_project(project_id, session.connection_id, &request.worktrees)
2001        .await?;
2002    broadcast(
2003        Some(session.connection_id),
2004        guest_connection_ids.iter().copied(),
2005        |connection_id| {
2006            session
2007                .peer
2008                .forward_send(session.connection_id, connection_id, request.clone())
2009        },
2010    );
2011    if let Some(room) = room {
2012        room_updated(room, &session.peer);
2013    }
2014    response.send(proto::Ack {})?;
2015
2016    Ok(())
2017}
2018
2019/// Updates other participants with changes to the worktree
2020async fn update_worktree(
2021    request: proto::UpdateWorktree,
2022    response: Response<proto::UpdateWorktree>,
2023    session: Session,
2024) -> Result<()> {
2025    let guest_connection_ids = session
2026        .db()
2027        .await
2028        .update_worktree(&request, session.connection_id)
2029        .await?;
2030
2031    broadcast(
2032        Some(session.connection_id),
2033        guest_connection_ids.iter().copied(),
2034        |connection_id| {
2035            session
2036                .peer
2037                .forward_send(session.connection_id, connection_id, request.clone())
2038        },
2039    );
2040    response.send(proto::Ack {})?;
2041    Ok(())
2042}
2043
2044async fn update_repository(
2045    request: proto::UpdateRepository,
2046    response: Response<proto::UpdateRepository>,
2047    session: Session,
2048) -> Result<()> {
2049    let guest_connection_ids = session
2050        .db()
2051        .await
2052        .update_repository(&request, session.connection_id)
2053        .await?;
2054
2055    broadcast(
2056        Some(session.connection_id),
2057        guest_connection_ids.iter().copied(),
2058        |connection_id| {
2059            session
2060                .peer
2061                .forward_send(session.connection_id, connection_id, request.clone())
2062        },
2063    );
2064    response.send(proto::Ack {})?;
2065    Ok(())
2066}
2067
2068async fn remove_repository(
2069    request: proto::RemoveRepository,
2070    response: Response<proto::RemoveRepository>,
2071    session: Session,
2072) -> Result<()> {
2073    let guest_connection_ids = session
2074        .db()
2075        .await
2076        .remove_repository(&request, session.connection_id)
2077        .await?;
2078
2079    broadcast(
2080        Some(session.connection_id),
2081        guest_connection_ids.iter().copied(),
2082        |connection_id| {
2083            session
2084                .peer
2085                .forward_send(session.connection_id, connection_id, request.clone())
2086        },
2087    );
2088    response.send(proto::Ack {})?;
2089    Ok(())
2090}
2091
2092/// Updates other participants with changes to the diagnostics
2093async fn update_diagnostic_summary(
2094    message: proto::UpdateDiagnosticSummary,
2095    session: Session,
2096) -> Result<()> {
2097    let guest_connection_ids = session
2098        .db()
2099        .await
2100        .update_diagnostic_summary(&message, session.connection_id)
2101        .await?;
2102
2103    broadcast(
2104        Some(session.connection_id),
2105        guest_connection_ids.iter().copied(),
2106        |connection_id| {
2107            session
2108                .peer
2109                .forward_send(session.connection_id, connection_id, message.clone())
2110        },
2111    );
2112
2113    Ok(())
2114}
2115
2116/// Updates other participants with changes to the worktree settings
2117async fn update_worktree_settings(
2118    message: proto::UpdateWorktreeSettings,
2119    session: Session,
2120) -> Result<()> {
2121    let guest_connection_ids = session
2122        .db()
2123        .await
2124        .update_worktree_settings(&message, session.connection_id)
2125        .await?;
2126
2127    broadcast(
2128        Some(session.connection_id),
2129        guest_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, message.clone())
2134        },
2135    );
2136
2137    Ok(())
2138}
2139
2140/// Notify other participants that a language server has started.
2141async fn start_language_server(
2142    request: proto::StartLanguageServer,
2143    session: Session,
2144) -> Result<()> {
2145    let guest_connection_ids = session
2146        .db()
2147        .await
2148        .start_language_server(&request, session.connection_id)
2149        .await?;
2150
2151    broadcast(
2152        Some(session.connection_id),
2153        guest_connection_ids.iter().copied(),
2154        |connection_id| {
2155            session
2156                .peer
2157                .forward_send(session.connection_id, connection_id, request.clone())
2158        },
2159    );
2160    Ok(())
2161}
2162
2163/// Notify other participants that a language server has changed.
2164async fn update_language_server(
2165    request: proto::UpdateLanguageServer,
2166    session: Session,
2167) -> Result<()> {
2168    let project_id = ProjectId::from_proto(request.project_id);
2169    let project_connection_ids = session
2170        .db()
2171        .await
2172        .project_connection_ids(project_id, session.connection_id, true)
2173        .await?;
2174    broadcast(
2175        Some(session.connection_id),
2176        project_connection_ids.iter().copied(),
2177        |connection_id| {
2178            session
2179                .peer
2180                .forward_send(session.connection_id, connection_id, request.clone())
2181        },
2182    );
2183    Ok(())
2184}
2185
2186/// forward a project request to the host. These requests should be read only
2187/// as guests are allowed to send them.
2188async fn forward_read_only_project_request<T>(
2189    request: T,
2190    response: Response<T>,
2191    session: Session,
2192) -> Result<()>
2193where
2194    T: EntityMessage + RequestMessage,
2195{
2196    let project_id = ProjectId::from_proto(request.remote_entity_id());
2197    let host_connection_id = session
2198        .db()
2199        .await
2200        .host_for_read_only_project_request(project_id, session.connection_id)
2201        .await?;
2202    let payload = session
2203        .peer
2204        .forward_request(session.connection_id, host_connection_id, request)
2205        .await?;
2206    response.send(payload)?;
2207    Ok(())
2208}
2209
2210async fn forward_find_search_candidates_request(
2211    request: proto::FindSearchCandidates,
2212    response: Response<proto::FindSearchCandidates>,
2213    session: Session,
2214) -> Result<()> {
2215    let project_id = ProjectId::from_proto(request.remote_entity_id());
2216    let host_connection_id = session
2217        .db()
2218        .await
2219        .host_for_read_only_project_request(project_id, session.connection_id)
2220        .await?;
2221    let payload = session
2222        .peer
2223        .forward_request(session.connection_id, host_connection_id, request)
2224        .await?;
2225    response.send(payload)?;
2226    Ok(())
2227}
2228
2229/// forward a project request to the host. These requests are disallowed
2230/// for guests.
2231async fn forward_mutating_project_request<T>(
2232    request: T,
2233    response: Response<T>,
2234    session: Session,
2235) -> Result<()>
2236where
2237    T: EntityMessage + RequestMessage,
2238{
2239    let project_id = ProjectId::from_proto(request.remote_entity_id());
2240
2241    let host_connection_id = session
2242        .db()
2243        .await
2244        .host_for_mutating_project_request(project_id, session.connection_id)
2245        .await?;
2246    let payload = session
2247        .peer
2248        .forward_request(session.connection_id, host_connection_id, request)
2249        .await?;
2250    response.send(payload)?;
2251    Ok(())
2252}
2253
2254/// Notify other participants that a new buffer has been created
2255async fn create_buffer_for_peer(
2256    request: proto::CreateBufferForPeer,
2257    session: Session,
2258) -> Result<()> {
2259    session
2260        .db()
2261        .await
2262        .check_user_is_project_host(
2263            ProjectId::from_proto(request.project_id),
2264            session.connection_id,
2265        )
2266        .await?;
2267    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2268    session
2269        .peer
2270        .forward_send(session.connection_id, peer_id.into(), request)?;
2271    Ok(())
2272}
2273
2274/// Notify other participants that a buffer has been updated. This is
2275/// allowed for guests as long as the update is limited to selections.
2276async fn update_buffer(
2277    request: proto::UpdateBuffer,
2278    response: Response<proto::UpdateBuffer>,
2279    session: Session,
2280) -> Result<()> {
2281    let project_id = ProjectId::from_proto(request.project_id);
2282    let mut capability = Capability::ReadOnly;
2283
2284    for op in request.operations.iter() {
2285        match op.variant {
2286            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2287            Some(_) => capability = Capability::ReadWrite,
2288        }
2289    }
2290
2291    let host = {
2292        let guard = session
2293            .db()
2294            .await
2295            .connections_for_buffer_update(project_id, session.connection_id, capability)
2296            .await?;
2297
2298        let (host, guests) = &*guard;
2299
2300        broadcast(
2301            Some(session.connection_id),
2302            guests.clone(),
2303            |connection_id| {
2304                session
2305                    .peer
2306                    .forward_send(session.connection_id, connection_id, request.clone())
2307            },
2308        );
2309
2310        *host
2311    };
2312
2313    if host != session.connection_id {
2314        session
2315            .peer
2316            .forward_request(session.connection_id, host, request.clone())
2317            .await?;
2318    }
2319
2320    response.send(proto::Ack {})?;
2321    Ok(())
2322}
2323
2324async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2325    let project_id = ProjectId::from_proto(message.project_id);
2326
2327    let operation = message.operation.as_ref().context("invalid operation")?;
2328    let capability = match operation.variant.as_ref() {
2329        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2330            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2331                match buffer_op.variant {
2332                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2333                        Capability::ReadOnly
2334                    }
2335                    _ => Capability::ReadWrite,
2336                }
2337            } else {
2338                Capability::ReadWrite
2339            }
2340        }
2341        Some(_) => Capability::ReadWrite,
2342        None => Capability::ReadOnly,
2343    };
2344
2345    let guard = session
2346        .db()
2347        .await
2348        .connections_for_buffer_update(project_id, session.connection_id, capability)
2349        .await?;
2350
2351    let (host, guests) = &*guard;
2352
2353    broadcast(
2354        Some(session.connection_id),
2355        guests.iter().chain([host]).copied(),
2356        |connection_id| {
2357            session
2358                .peer
2359                .forward_send(session.connection_id, connection_id, message.clone())
2360        },
2361    );
2362
2363    Ok(())
2364}
2365
2366/// Notify other participants that a project has been updated.
2367async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2368    request: T,
2369    session: Session,
2370) -> Result<()> {
2371    let project_id = ProjectId::from_proto(request.remote_entity_id());
2372    let project_connection_ids = session
2373        .db()
2374        .await
2375        .project_connection_ids(project_id, session.connection_id, false)
2376        .await?;
2377
2378    broadcast(
2379        Some(session.connection_id),
2380        project_connection_ids.iter().copied(),
2381        |connection_id| {
2382            session
2383                .peer
2384                .forward_send(session.connection_id, connection_id, request.clone())
2385        },
2386    );
2387    Ok(())
2388}
2389
2390/// Start following another user in a call.
2391async fn follow(
2392    request: proto::Follow,
2393    response: Response<proto::Follow>,
2394    session: Session,
2395) -> Result<()> {
2396    let room_id = RoomId::from_proto(request.room_id);
2397    let project_id = request.project_id.map(ProjectId::from_proto);
2398    let leader_id = request
2399        .leader_id
2400        .ok_or_else(|| anyhow!("invalid leader id"))?
2401        .into();
2402    let follower_id = session.connection_id;
2403
2404    session
2405        .db()
2406        .await
2407        .check_room_participants(room_id, leader_id, session.connection_id)
2408        .await?;
2409
2410    let response_payload = session
2411        .peer
2412        .forward_request(session.connection_id, leader_id, request)
2413        .await?;
2414    response.send(response_payload)?;
2415
2416    if let Some(project_id) = project_id {
2417        let room = session
2418            .db()
2419            .await
2420            .follow(room_id, project_id, leader_id, follower_id)
2421            .await?;
2422        room_updated(&room, &session.peer);
2423    }
2424
2425    Ok(())
2426}
2427
2428/// Stop following another user in a call.
2429async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2430    let room_id = RoomId::from_proto(request.room_id);
2431    let project_id = request.project_id.map(ProjectId::from_proto);
2432    let leader_id = request
2433        .leader_id
2434        .ok_or_else(|| anyhow!("invalid leader id"))?
2435        .into();
2436    let follower_id = session.connection_id;
2437
2438    session
2439        .db()
2440        .await
2441        .check_room_participants(room_id, leader_id, session.connection_id)
2442        .await?;
2443
2444    session
2445        .peer
2446        .forward_send(session.connection_id, leader_id, request)?;
2447
2448    if let Some(project_id) = project_id {
2449        let room = session
2450            .db()
2451            .await
2452            .unfollow(room_id, project_id, leader_id, follower_id)
2453            .await?;
2454        room_updated(&room, &session.peer);
2455    }
2456
2457    Ok(())
2458}
2459
2460/// Notify everyone following you of your current location.
2461async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2462    let room_id = RoomId::from_proto(request.room_id);
2463    let database = session.db.lock().await;
2464
2465    let connection_ids = if let Some(project_id) = request.project_id {
2466        let project_id = ProjectId::from_proto(project_id);
2467        database
2468            .project_connection_ids(project_id, session.connection_id, true)
2469            .await?
2470    } else {
2471        database
2472            .room_connection_ids(room_id, session.connection_id)
2473            .await?
2474    };
2475
2476    // For now, don't send view update messages back to that view's current leader.
2477    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2478        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2479        _ => None,
2480    });
2481
2482    for connection_id in connection_ids.iter().cloned() {
2483        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2484            session
2485                .peer
2486                .forward_send(session.connection_id, connection_id, request.clone())?;
2487        }
2488    }
2489    Ok(())
2490}
2491
2492/// Get public data about users.
2493async fn get_users(
2494    request: proto::GetUsers,
2495    response: Response<proto::GetUsers>,
2496    session: Session,
2497) -> Result<()> {
2498    let user_ids = request
2499        .user_ids
2500        .into_iter()
2501        .map(UserId::from_proto)
2502        .collect();
2503    let users = session
2504        .db()
2505        .await
2506        .get_users_by_ids(user_ids)
2507        .await?
2508        .into_iter()
2509        .map(|user| proto::User {
2510            id: user.id.to_proto(),
2511            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2512            github_login: user.github_login,
2513            email: user.email_address,
2514            name: user.name,
2515        })
2516        .collect();
2517    response.send(proto::UsersResponse { users })?;
2518    Ok(())
2519}
2520
2521/// Search for users (to invite) buy Github login
2522async fn fuzzy_search_users(
2523    request: proto::FuzzySearchUsers,
2524    response: Response<proto::FuzzySearchUsers>,
2525    session: Session,
2526) -> Result<()> {
2527    let query = request.query;
2528    let users = match query.len() {
2529        0 => vec![],
2530        1 | 2 => session
2531            .db()
2532            .await
2533            .get_user_by_github_login(&query)
2534            .await?
2535            .into_iter()
2536            .collect(),
2537        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2538    };
2539    let users = users
2540        .into_iter()
2541        .filter(|user| user.id != session.user_id())
2542        .map(|user| proto::User {
2543            id: user.id.to_proto(),
2544            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2545            github_login: user.github_login,
2546            name: user.name,
2547            email: user.email_address,
2548        })
2549        .collect();
2550    response.send(proto::UsersResponse { users })?;
2551    Ok(())
2552}
2553
2554/// Send a contact request to another user.
2555async fn request_contact(
2556    request: proto::RequestContact,
2557    response: Response<proto::RequestContact>,
2558    session: Session,
2559) -> Result<()> {
2560    let requester_id = session.user_id();
2561    let responder_id = UserId::from_proto(request.responder_id);
2562    if requester_id == responder_id {
2563        return Err(anyhow!("cannot add yourself as a contact"))?;
2564    }
2565
2566    let notifications = session
2567        .db()
2568        .await
2569        .send_contact_request(requester_id, responder_id)
2570        .await?;
2571
2572    // Update outgoing contact requests of requester
2573    let mut update = proto::UpdateContacts::default();
2574    update.outgoing_requests.push(responder_id.to_proto());
2575    for connection_id in session
2576        .connection_pool()
2577        .await
2578        .user_connection_ids(requester_id)
2579    {
2580        session.peer.send(connection_id, update.clone())?;
2581    }
2582
2583    // Update incoming contact requests of responder
2584    let mut update = proto::UpdateContacts::default();
2585    update
2586        .incoming_requests
2587        .push(proto::IncomingContactRequest {
2588            requester_id: requester_id.to_proto(),
2589        });
2590    let connection_pool = session.connection_pool().await;
2591    for connection_id in connection_pool.user_connection_ids(responder_id) {
2592        session.peer.send(connection_id, update.clone())?;
2593    }
2594
2595    send_notifications(&connection_pool, &session.peer, notifications);
2596
2597    response.send(proto::Ack {})?;
2598    Ok(())
2599}
2600
2601/// Accept or decline a contact request
2602async fn respond_to_contact_request(
2603    request: proto::RespondToContactRequest,
2604    response: Response<proto::RespondToContactRequest>,
2605    session: Session,
2606) -> Result<()> {
2607    let responder_id = session.user_id();
2608    let requester_id = UserId::from_proto(request.requester_id);
2609    let db = session.db().await;
2610    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2611        db.dismiss_contact_notification(responder_id, requester_id)
2612            .await?;
2613    } else {
2614        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2615
2616        let notifications = db
2617            .respond_to_contact_request(responder_id, requester_id, accept)
2618            .await?;
2619        let requester_busy = db.is_user_busy(requester_id).await?;
2620        let responder_busy = db.is_user_busy(responder_id).await?;
2621
2622        let pool = session.connection_pool().await;
2623        // Update responder with new contact
2624        let mut update = proto::UpdateContacts::default();
2625        if accept {
2626            update
2627                .contacts
2628                .push(contact_for_user(requester_id, requester_busy, &pool));
2629        }
2630        update
2631            .remove_incoming_requests
2632            .push(requester_id.to_proto());
2633        for connection_id in pool.user_connection_ids(responder_id) {
2634            session.peer.send(connection_id, update.clone())?;
2635        }
2636
2637        // Update requester with new contact
2638        let mut update = proto::UpdateContacts::default();
2639        if accept {
2640            update
2641                .contacts
2642                .push(contact_for_user(responder_id, responder_busy, &pool));
2643        }
2644        update
2645            .remove_outgoing_requests
2646            .push(responder_id.to_proto());
2647
2648        for connection_id in pool.user_connection_ids(requester_id) {
2649            session.peer.send(connection_id, update.clone())?;
2650        }
2651
2652        send_notifications(&pool, &session.peer, notifications);
2653    }
2654
2655    response.send(proto::Ack {})?;
2656    Ok(())
2657}
2658
2659/// Remove a contact.
2660async fn remove_contact(
2661    request: proto::RemoveContact,
2662    response: Response<proto::RemoveContact>,
2663    session: Session,
2664) -> Result<()> {
2665    let requester_id = session.user_id();
2666    let responder_id = UserId::from_proto(request.user_id);
2667    let db = session.db().await;
2668    let (contact_accepted, deleted_notification_id) =
2669        db.remove_contact(requester_id, responder_id).await?;
2670
2671    let pool = session.connection_pool().await;
2672    // Update outgoing contact requests of requester
2673    let mut update = proto::UpdateContacts::default();
2674    if contact_accepted {
2675        update.remove_contacts.push(responder_id.to_proto());
2676    } else {
2677        update
2678            .remove_outgoing_requests
2679            .push(responder_id.to_proto());
2680    }
2681    for connection_id in pool.user_connection_ids(requester_id) {
2682        session.peer.send(connection_id, update.clone())?;
2683    }
2684
2685    // Update incoming contact requests of responder
2686    let mut update = proto::UpdateContacts::default();
2687    if contact_accepted {
2688        update.remove_contacts.push(requester_id.to_proto());
2689    } else {
2690        update
2691            .remove_incoming_requests
2692            .push(requester_id.to_proto());
2693    }
2694    for connection_id in pool.user_connection_ids(responder_id) {
2695        session.peer.send(connection_id, update.clone())?;
2696        if let Some(notification_id) = deleted_notification_id {
2697            session.peer.send(
2698                connection_id,
2699                proto::DeleteNotification {
2700                    notification_id: notification_id.to_proto(),
2701                },
2702            )?;
2703        }
2704    }
2705
2706    response.send(proto::Ack {})?;
2707    Ok(())
2708}
2709
2710fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2711    version.0.minor() < 139
2712}
2713
2714async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2715    let plan = session.current_plan(&session.db().await).await?;
2716
2717    session
2718        .peer
2719        .send(
2720            session.connection_id,
2721            proto::UpdateUserPlan { plan: plan.into() },
2722        )
2723        .trace_err();
2724
2725    Ok(())
2726}
2727
2728async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2729    subscribe_user_to_channels(session.user_id(), &session).await?;
2730    Ok(())
2731}
2732
2733async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2734    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2735    let mut pool = session.connection_pool().await;
2736    for membership in &channels_for_user.channel_memberships {
2737        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2738    }
2739    session.peer.send(
2740        session.connection_id,
2741        build_update_user_channels(&channels_for_user),
2742    )?;
2743    session.peer.send(
2744        session.connection_id,
2745        build_channels_update(channels_for_user),
2746    )?;
2747    Ok(())
2748}
2749
2750/// Creates a new channel.
2751async fn create_channel(
2752    request: proto::CreateChannel,
2753    response: Response<proto::CreateChannel>,
2754    session: Session,
2755) -> Result<()> {
2756    let db = session.db().await;
2757
2758    let parent_id = request.parent_id.map(ChannelId::from_proto);
2759    let (channel, membership) = db
2760        .create_channel(&request.name, parent_id, session.user_id())
2761        .await?;
2762
2763    let root_id = channel.root_id();
2764    let channel = Channel::from_model(channel);
2765
2766    response.send(proto::CreateChannelResponse {
2767        channel: Some(channel.to_proto()),
2768        parent_id: request.parent_id,
2769    })?;
2770
2771    let mut connection_pool = session.connection_pool().await;
2772    if let Some(membership) = membership {
2773        connection_pool.subscribe_to_channel(
2774            membership.user_id,
2775            membership.channel_id,
2776            membership.role,
2777        );
2778        let update = proto::UpdateUserChannels {
2779            channel_memberships: vec![proto::ChannelMembership {
2780                channel_id: membership.channel_id.to_proto(),
2781                role: membership.role.into(),
2782            }],
2783            ..Default::default()
2784        };
2785        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2786            session.peer.send(connection_id, update.clone())?;
2787        }
2788    }
2789
2790    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2791        if !role.can_see_channel(channel.visibility) {
2792            continue;
2793        }
2794
2795        let update = proto::UpdateChannels {
2796            channels: vec![channel.to_proto()],
2797            ..Default::default()
2798        };
2799        session.peer.send(connection_id, update.clone())?;
2800    }
2801
2802    Ok(())
2803}
2804
2805/// Delete a channel
2806async fn delete_channel(
2807    request: proto::DeleteChannel,
2808    response: Response<proto::DeleteChannel>,
2809    session: Session,
2810) -> Result<()> {
2811    let db = session.db().await;
2812
2813    let channel_id = request.channel_id;
2814    let (root_channel, removed_channels) = db
2815        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2816        .await?;
2817    response.send(proto::Ack {})?;
2818
2819    // Notify members of removed channels
2820    let mut update = proto::UpdateChannels::default();
2821    update
2822        .delete_channels
2823        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2824
2825    let connection_pool = session.connection_pool().await;
2826    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2827        session.peer.send(connection_id, update.clone())?;
2828    }
2829
2830    Ok(())
2831}
2832
2833/// Invite someone to join a channel.
2834async fn invite_channel_member(
2835    request: proto::InviteChannelMember,
2836    response: Response<proto::InviteChannelMember>,
2837    session: Session,
2838) -> Result<()> {
2839    let db = session.db().await;
2840    let channel_id = ChannelId::from_proto(request.channel_id);
2841    let invitee_id = UserId::from_proto(request.user_id);
2842    let InviteMemberResult {
2843        channel,
2844        notifications,
2845    } = db
2846        .invite_channel_member(
2847            channel_id,
2848            invitee_id,
2849            session.user_id(),
2850            request.role().into(),
2851        )
2852        .await?;
2853
2854    let update = proto::UpdateChannels {
2855        channel_invitations: vec![channel.to_proto()],
2856        ..Default::default()
2857    };
2858
2859    let connection_pool = session.connection_pool().await;
2860    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2861        session.peer.send(connection_id, update.clone())?;
2862    }
2863
2864    send_notifications(&connection_pool, &session.peer, notifications);
2865
2866    response.send(proto::Ack {})?;
2867    Ok(())
2868}
2869
2870/// remove someone from a channel
2871async fn remove_channel_member(
2872    request: proto::RemoveChannelMember,
2873    response: Response<proto::RemoveChannelMember>,
2874    session: Session,
2875) -> Result<()> {
2876    let db = session.db().await;
2877    let channel_id = ChannelId::from_proto(request.channel_id);
2878    let member_id = UserId::from_proto(request.user_id);
2879
2880    let RemoveChannelMemberResult {
2881        membership_update,
2882        notification_id,
2883    } = db
2884        .remove_channel_member(channel_id, member_id, session.user_id())
2885        .await?;
2886
2887    let mut connection_pool = session.connection_pool().await;
2888    notify_membership_updated(
2889        &mut connection_pool,
2890        membership_update,
2891        member_id,
2892        &session.peer,
2893    );
2894    for connection_id in connection_pool.user_connection_ids(member_id) {
2895        if let Some(notification_id) = notification_id {
2896            session
2897                .peer
2898                .send(
2899                    connection_id,
2900                    proto::DeleteNotification {
2901                        notification_id: notification_id.to_proto(),
2902                    },
2903                )
2904                .trace_err();
2905        }
2906    }
2907
2908    response.send(proto::Ack {})?;
2909    Ok(())
2910}
2911
2912/// Toggle the channel between public and private.
2913/// Care is taken to maintain the invariant that public channels only descend from public channels,
2914/// (though members-only channels can appear at any point in the hierarchy).
2915async fn set_channel_visibility(
2916    request: proto::SetChannelVisibility,
2917    response: Response<proto::SetChannelVisibility>,
2918    session: Session,
2919) -> Result<()> {
2920    let db = session.db().await;
2921    let channel_id = ChannelId::from_proto(request.channel_id);
2922    let visibility = request.visibility().into();
2923
2924    let channel_model = db
2925        .set_channel_visibility(channel_id, visibility, session.user_id())
2926        .await?;
2927    let root_id = channel_model.root_id();
2928    let channel = Channel::from_model(channel_model);
2929
2930    let mut connection_pool = session.connection_pool().await;
2931    for (user_id, role) in connection_pool
2932        .channel_user_ids(root_id)
2933        .collect::<Vec<_>>()
2934        .into_iter()
2935    {
2936        let update = if role.can_see_channel(channel.visibility) {
2937            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2938            proto::UpdateChannels {
2939                channels: vec![channel.to_proto()],
2940                ..Default::default()
2941            }
2942        } else {
2943            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2944            proto::UpdateChannels {
2945                delete_channels: vec![channel.id.to_proto()],
2946                ..Default::default()
2947            }
2948        };
2949
2950        for connection_id in connection_pool.user_connection_ids(user_id) {
2951            session.peer.send(connection_id, update.clone())?;
2952        }
2953    }
2954
2955    response.send(proto::Ack {})?;
2956    Ok(())
2957}
2958
2959/// Alter the role for a user in the channel.
2960async fn set_channel_member_role(
2961    request: proto::SetChannelMemberRole,
2962    response: Response<proto::SetChannelMemberRole>,
2963    session: Session,
2964) -> Result<()> {
2965    let db = session.db().await;
2966    let channel_id = ChannelId::from_proto(request.channel_id);
2967    let member_id = UserId::from_proto(request.user_id);
2968    let result = db
2969        .set_channel_member_role(
2970            channel_id,
2971            session.user_id(),
2972            member_id,
2973            request.role().into(),
2974        )
2975        .await?;
2976
2977    match result {
2978        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2979            let mut connection_pool = session.connection_pool().await;
2980            notify_membership_updated(
2981                &mut connection_pool,
2982                membership_update,
2983                member_id,
2984                &session.peer,
2985            )
2986        }
2987        db::SetMemberRoleResult::InviteUpdated(channel) => {
2988            let update = proto::UpdateChannels {
2989                channel_invitations: vec![channel.to_proto()],
2990                ..Default::default()
2991            };
2992
2993            for connection_id in session
2994                .connection_pool()
2995                .await
2996                .user_connection_ids(member_id)
2997            {
2998                session.peer.send(connection_id, update.clone())?;
2999            }
3000        }
3001    }
3002
3003    response.send(proto::Ack {})?;
3004    Ok(())
3005}
3006
3007/// Change the name of a channel
3008async fn rename_channel(
3009    request: proto::RenameChannel,
3010    response: Response<proto::RenameChannel>,
3011    session: Session,
3012) -> Result<()> {
3013    let db = session.db().await;
3014    let channel_id = ChannelId::from_proto(request.channel_id);
3015    let channel_model = db
3016        .rename_channel(channel_id, session.user_id(), &request.name)
3017        .await?;
3018    let root_id = channel_model.root_id();
3019    let channel = Channel::from_model(channel_model);
3020
3021    response.send(proto::RenameChannelResponse {
3022        channel: Some(channel.to_proto()),
3023    })?;
3024
3025    let connection_pool = session.connection_pool().await;
3026    let update = proto::UpdateChannels {
3027        channels: vec![channel.to_proto()],
3028        ..Default::default()
3029    };
3030    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3031        if role.can_see_channel(channel.visibility) {
3032            session.peer.send(connection_id, update.clone())?;
3033        }
3034    }
3035
3036    Ok(())
3037}
3038
3039/// Move a channel to a new parent.
3040async fn move_channel(
3041    request: proto::MoveChannel,
3042    response: Response<proto::MoveChannel>,
3043    session: Session,
3044) -> Result<()> {
3045    let channel_id = ChannelId::from_proto(request.channel_id);
3046    let to = ChannelId::from_proto(request.to);
3047
3048    let (root_id, channels) = session
3049        .db()
3050        .await
3051        .move_channel(channel_id, to, session.user_id())
3052        .await?;
3053
3054    let connection_pool = session.connection_pool().await;
3055    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3056        let channels = channels
3057            .iter()
3058            .filter_map(|channel| {
3059                if role.can_see_channel(channel.visibility) {
3060                    Some(channel.to_proto())
3061                } else {
3062                    None
3063                }
3064            })
3065            .collect::<Vec<_>>();
3066        if channels.is_empty() {
3067            continue;
3068        }
3069
3070        let update = proto::UpdateChannels {
3071            channels,
3072            ..Default::default()
3073        };
3074
3075        session.peer.send(connection_id, update.clone())?;
3076    }
3077
3078    response.send(Ack {})?;
3079    Ok(())
3080}
3081
3082/// Get the list of channel members
3083async fn get_channel_members(
3084    request: proto::GetChannelMembers,
3085    response: Response<proto::GetChannelMembers>,
3086    session: Session,
3087) -> Result<()> {
3088    let db = session.db().await;
3089    let channel_id = ChannelId::from_proto(request.channel_id);
3090    let limit = if request.limit == 0 {
3091        u16::MAX as u64
3092    } else {
3093        request.limit
3094    };
3095    let (members, users) = db
3096        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3097        .await?;
3098    response.send(proto::GetChannelMembersResponse { members, users })?;
3099    Ok(())
3100}
3101
3102/// Accept or decline a channel invitation.
3103async fn respond_to_channel_invite(
3104    request: proto::RespondToChannelInvite,
3105    response: Response<proto::RespondToChannelInvite>,
3106    session: Session,
3107) -> Result<()> {
3108    let db = session.db().await;
3109    let channel_id = ChannelId::from_proto(request.channel_id);
3110    let RespondToChannelInvite {
3111        membership_update,
3112        notifications,
3113    } = db
3114        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3115        .await?;
3116
3117    let mut connection_pool = session.connection_pool().await;
3118    if let Some(membership_update) = membership_update {
3119        notify_membership_updated(
3120            &mut connection_pool,
3121            membership_update,
3122            session.user_id(),
3123            &session.peer,
3124        );
3125    } else {
3126        let update = proto::UpdateChannels {
3127            remove_channel_invitations: vec![channel_id.to_proto()],
3128            ..Default::default()
3129        };
3130
3131        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3132            session.peer.send(connection_id, update.clone())?;
3133        }
3134    };
3135
3136    send_notifications(&connection_pool, &session.peer, notifications);
3137
3138    response.send(proto::Ack {})?;
3139
3140    Ok(())
3141}
3142
3143/// Join the channels' room
3144async fn join_channel(
3145    request: proto::JoinChannel,
3146    response: Response<proto::JoinChannel>,
3147    session: Session,
3148) -> Result<()> {
3149    let channel_id = ChannelId::from_proto(request.channel_id);
3150    join_channel_internal(channel_id, Box::new(response), session).await
3151}
3152
/// Reply handle accepted by `join_channel_internal`.
///
/// It lets the same join logic answer different request types with a
/// `proto::JoinRoomResponse`. Implemented for `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Lets `join_channel_internal` answer a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // In a `Type::method` path, inherent associated functions take precedence over trait
        // methods. This therefore calls `Response`'s own `send` rather than recursing into this
        // impl.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Lets `join_channel_internal` answer a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // In a `Type::method` path, inherent associated functions take precedence over trait
        // methods. This therefore calls `Response`'s own `send` rather than recursing into this
        // impl.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3166
/// Joins the call room attached to `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is made
///    to leave its room first.
/// 2. The user joins the channel's room. The database returns the room, an optional membership
///    change, and the user's role in the channel.
/// 3. If a LiveKit client is configured, LiveKit connection info is built for the reply:
///    - `Guest` role: a guest token with `can_publish: false`;
///    - any other role: a room token with `can_publish: true`.
/// 4. The reply is sent through `response`, any membership change is propagated, and the room
///    is broadcast via `room_updated`.
/// 5. The channel is broadcast via `channel_updated`, and the user's contacts are refreshed.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a database call fails;
/// - leaving the stale room fails;
/// - the reply cannot be sent;
/// - refreshing contacts fails;
/// - the joined room carries no channel (`"channel not returned"`).
///
/// The channel check runs only after the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the stale room. `leave_room_for_session`
            // receives the whole session and presumably acquires the guard itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token fails.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // A token error is logged by `trace_err`, and `?` then makes this closure
                    // return `None`. The join still proceeds, just without LiveKit info.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The inner block has ended, so its database guard and connection-pool guard are released.
    // The pool is re-acquired just for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3262
3263/// Start editing the channel notes
3264async fn join_channel_buffer(
3265    request: proto::JoinChannelBuffer,
3266    response: Response<proto::JoinChannelBuffer>,
3267    session: Session,
3268) -> Result<()> {
3269    let db = session.db().await;
3270    let channel_id = ChannelId::from_proto(request.channel_id);
3271
3272    let open_response = db
3273        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3274        .await?;
3275
3276    let collaborators = open_response.collaborators.clone();
3277    response.send(open_response)?;
3278
3279    let update = UpdateChannelBufferCollaborators {
3280        channel_id: channel_id.to_proto(),
3281        collaborators: collaborators.clone(),
3282    };
3283    channel_buffer_updated(
3284        session.connection_id,
3285        collaborators
3286            .iter()
3287            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3288        &update,
3289        &session.peer,
3290    );
3291
3292    Ok(())
3293}
3294
3295/// Edit the channel notes
3296async fn update_channel_buffer(
3297    request: proto::UpdateChannelBuffer,
3298    session: Session,
3299) -> Result<()> {
3300    let db = session.db().await;
3301    let channel_id = ChannelId::from_proto(request.channel_id);
3302
3303    let (collaborators, epoch, version) = db
3304        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3305        .await?;
3306
3307    channel_buffer_updated(
3308        session.connection_id,
3309        collaborators.clone(),
3310        &proto::UpdateChannelBuffer {
3311            channel_id: channel_id.to_proto(),
3312            operations: request.operations,
3313        },
3314        &session.peer,
3315    );
3316
3317    let pool = &*session.connection_pool().await;
3318
3319    let non_collaborators =
3320        pool.channel_connection_ids(channel_id)
3321            .filter_map(|(connection_id, _)| {
3322                if collaborators.contains(&connection_id) {
3323                    None
3324                } else {
3325                    Some(connection_id)
3326                }
3327            });
3328
3329    broadcast(None, non_collaborators, |peer_id| {
3330        session.peer.send(
3331            peer_id,
3332            proto::UpdateChannels {
3333                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3334                    channel_id: channel_id.to_proto(),
3335                    epoch: epoch as u64,
3336                    version: version.clone(),
3337                }],
3338                ..Default::default()
3339            },
3340        )
3341    });
3342
3343    Ok(())
3344}
3345
3346/// Rejoin the channel notes after a connection blip
3347async fn rejoin_channel_buffers(
3348    request: proto::RejoinChannelBuffers,
3349    response: Response<proto::RejoinChannelBuffers>,
3350    session: Session,
3351) -> Result<()> {
3352    let db = session.db().await;
3353    let buffers = db
3354        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3355        .await?;
3356
3357    for rejoined_buffer in &buffers {
3358        let collaborators_to_notify = rejoined_buffer
3359            .buffer
3360            .collaborators
3361            .iter()
3362            .filter_map(|c| Some(c.peer_id?.into()));
3363        channel_buffer_updated(
3364            session.connection_id,
3365            collaborators_to_notify,
3366            &proto::UpdateChannelBufferCollaborators {
3367                channel_id: rejoined_buffer.buffer.channel_id,
3368                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3369            },
3370            &session.peer,
3371        );
3372    }
3373
3374    response.send(proto::RejoinChannelBuffersResponse {
3375        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3376    })?;
3377
3378    Ok(())
3379}
3380
3381/// Stop editing the channel notes
3382async fn leave_channel_buffer(
3383    request: proto::LeaveChannelBuffer,
3384    response: Response<proto::LeaveChannelBuffer>,
3385    session: Session,
3386) -> Result<()> {
3387    let db = session.db().await;
3388    let channel_id = ChannelId::from_proto(request.channel_id);
3389
3390    let left_buffer = db
3391        .leave_channel_buffer(channel_id, session.connection_id)
3392        .await?;
3393
3394    response.send(Ack {})?;
3395
3396    channel_buffer_updated(
3397        session.connection_id,
3398        left_buffer.connections,
3399        &proto::UpdateChannelBufferCollaborators {
3400            channel_id: channel_id.to_proto(),
3401            collaborators: left_buffer.collaborators,
3402        },
3403        &session.peer,
3404    );
3405
3406    Ok(())
3407}
3408
3409fn channel_buffer_updated<T: EnvelopedMessage>(
3410    sender_id: ConnectionId,
3411    collaborators: impl IntoIterator<Item = ConnectionId>,
3412    message: &T,
3413    peer: &Peer,
3414) {
3415    broadcast(Some(sender_id), collaborators, |peer_id| {
3416        peer.send(peer_id, message.clone())
3417    });
3418}
3419
3420fn send_notifications(
3421    connection_pool: &ConnectionPool,
3422    peer: &Peer,
3423    notifications: db::NotificationBatch,
3424) {
3425    for (user_id, notification) in notifications {
3426        for connection_id in connection_pool.user_connection_ids(user_id) {
3427            if let Err(error) = peer.send(
3428                connection_id,
3429                proto::AddNotification {
3430                    notification: Some(notification.clone()),
3431                },
3432            ) {
3433                tracing::error!(
3434                    "failed to send notification to {:?} {}",
3435                    connection_id,
3436                    error
3437                );
3438            }
3439        }
3440    }
3441}
3442
3443/// Send a message to the channel
3444async fn send_channel_message(
3445    request: proto::SendChannelMessage,
3446    response: Response<proto::SendChannelMessage>,
3447    session: Session,
3448) -> Result<()> {
3449    // Validate the message body.
3450    let body = request.body.trim().to_string();
3451    if body.len() > MAX_MESSAGE_LEN {
3452        return Err(anyhow!("message is too long"))?;
3453    }
3454    if body.is_empty() {
3455        return Err(anyhow!("message can't be blank"))?;
3456    }
3457
3458    // TODO: adjust mentions if body is trimmed
3459
3460    let timestamp = OffsetDateTime::now_utc();
3461    let nonce = request
3462        .nonce
3463        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3464
3465    let channel_id = ChannelId::from_proto(request.channel_id);
3466    let CreatedChannelMessage {
3467        message_id,
3468        participant_connection_ids,
3469        notifications,
3470    } = session
3471        .db()
3472        .await
3473        .create_channel_message(
3474            channel_id,
3475            session.user_id(),
3476            &body,
3477            &request.mentions,
3478            timestamp,
3479            nonce.clone().into(),
3480            request.reply_to_message_id.map(MessageId::from_proto),
3481        )
3482        .await?;
3483
3484    let message = proto::ChannelMessage {
3485        sender_id: session.user_id().to_proto(),
3486        id: message_id.to_proto(),
3487        body,
3488        mentions: request.mentions,
3489        timestamp: timestamp.unix_timestamp() as u64,
3490        nonce: Some(nonce),
3491        reply_to_message_id: request.reply_to_message_id,
3492        edited_at: None,
3493    };
3494    broadcast(
3495        Some(session.connection_id),
3496        participant_connection_ids.clone(),
3497        |connection| {
3498            session.peer.send(
3499                connection,
3500                proto::ChannelMessageSent {
3501                    channel_id: channel_id.to_proto(),
3502                    message: Some(message.clone()),
3503                },
3504            )
3505        },
3506    );
3507    response.send(proto::SendChannelMessageResponse {
3508        message: Some(message),
3509    })?;
3510
3511    let pool = &*session.connection_pool().await;
3512    let non_participants =
3513        pool.channel_connection_ids(channel_id)
3514            .filter_map(|(connection_id, _)| {
3515                if participant_connection_ids.contains(&connection_id) {
3516                    None
3517                } else {
3518                    Some(connection_id)
3519                }
3520            });
3521    broadcast(None, non_participants, |peer_id| {
3522        session.peer.send(
3523            peer_id,
3524            proto::UpdateChannels {
3525                latest_channel_message_ids: vec![proto::ChannelMessageId {
3526                    channel_id: channel_id.to_proto(),
3527                    message_id: message_id.to_proto(),
3528                }],
3529                ..Default::default()
3530            },
3531        )
3532    });
3533    send_notifications(pool, &session.peer, notifications);
3534
3535    Ok(())
3536}
3537
3538/// Delete a channel message
3539async fn remove_channel_message(
3540    request: proto::RemoveChannelMessage,
3541    response: Response<proto::RemoveChannelMessage>,
3542    session: Session,
3543) -> Result<()> {
3544    let channel_id = ChannelId::from_proto(request.channel_id);
3545    let message_id = MessageId::from_proto(request.message_id);
3546    let (connection_ids, existing_notification_ids) = session
3547        .db()
3548        .await
3549        .remove_channel_message(channel_id, message_id, session.user_id())
3550        .await?;
3551
3552    broadcast(
3553        Some(session.connection_id),
3554        connection_ids,
3555        move |connection| {
3556            session.peer.send(connection, request.clone())?;
3557
3558            for notification_id in &existing_notification_ids {
3559                session.peer.send(
3560                    connection,
3561                    proto::DeleteNotification {
3562                        notification_id: (*notification_id).to_proto(),
3563                    },
3564                )?;
3565            }
3566
3567            Ok(())
3568        },
3569    );
3570    response.send(proto::Ack {})?;
3571    Ok(())
3572}
3573
3574async fn update_channel_message(
3575    request: proto::UpdateChannelMessage,
3576    response: Response<proto::UpdateChannelMessage>,
3577    session: Session,
3578) -> Result<()> {
3579    let channel_id = ChannelId::from_proto(request.channel_id);
3580    let message_id = MessageId::from_proto(request.message_id);
3581    let updated_at = OffsetDateTime::now_utc();
3582    let UpdatedChannelMessage {
3583        message_id,
3584        participant_connection_ids,
3585        notifications,
3586        reply_to_message_id,
3587        timestamp,
3588        deleted_mention_notification_ids,
3589        updated_mention_notifications,
3590    } = session
3591        .db()
3592        .await
3593        .update_channel_message(
3594            channel_id,
3595            message_id,
3596            session.user_id(),
3597            request.body.as_str(),
3598            &request.mentions,
3599            updated_at,
3600        )
3601        .await?;
3602
3603    let nonce = request
3604        .nonce
3605        .clone()
3606        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3607
3608    let message = proto::ChannelMessage {
3609        sender_id: session.user_id().to_proto(),
3610        id: message_id.to_proto(),
3611        body: request.body.clone(),
3612        mentions: request.mentions.clone(),
3613        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3614        nonce: Some(nonce),
3615        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3616        edited_at: Some(updated_at.unix_timestamp() as u64),
3617    };
3618
3619    response.send(proto::Ack {})?;
3620
3621    let pool = &*session.connection_pool().await;
3622    broadcast(
3623        Some(session.connection_id),
3624        participant_connection_ids,
3625        |connection| {
3626            session.peer.send(
3627                connection,
3628                proto::ChannelMessageUpdate {
3629                    channel_id: channel_id.to_proto(),
3630                    message: Some(message.clone()),
3631                },
3632            )?;
3633
3634            for notification_id in &deleted_mention_notification_ids {
3635                session.peer.send(
3636                    connection,
3637                    proto::DeleteNotification {
3638                        notification_id: (*notification_id).to_proto(),
3639                    },
3640                )?;
3641            }
3642
3643            for notification in &updated_mention_notifications {
3644                session.peer.send(
3645                    connection,
3646                    proto::UpdateNotification {
3647                        notification: Some(notification.clone()),
3648                    },
3649                )?;
3650            }
3651
3652            Ok(())
3653        },
3654    );
3655
3656    send_notifications(pool, &session.peer, notifications);
3657
3658    Ok(())
3659}
3660
3661/// Mark a channel message as read
3662async fn acknowledge_channel_message(
3663    request: proto::AckChannelMessage,
3664    session: Session,
3665) -> Result<()> {
3666    let channel_id = ChannelId::from_proto(request.channel_id);
3667    let message_id = MessageId::from_proto(request.message_id);
3668    let notifications = session
3669        .db()
3670        .await
3671        .observe_channel_message(channel_id, session.user_id(), message_id)
3672        .await?;
3673    send_notifications(
3674        &*session.connection_pool().await,
3675        &session.peer,
3676        notifications,
3677    );
3678    Ok(())
3679}
3680
3681/// Mark a buffer version as synced
3682async fn acknowledge_buffer_version(
3683    request: proto::AckBufferOperation,
3684    session: Session,
3685) -> Result<()> {
3686    let buffer_id = BufferId::from_proto(request.buffer_id);
3687    session
3688        .db()
3689        .await
3690        .observe_buffer_version(
3691            buffer_id,
3692            session.user_id(),
3693            request.epoch as i32,
3694            &request.version,
3695        )
3696        .await?;
3697    Ok(())
3698}
3699
3700async fn count_language_model_tokens(
3701    request: proto::CountLanguageModelTokens,
3702    response: Response<proto::CountLanguageModelTokens>,
3703    session: Session,
3704    config: &Config,
3705) -> Result<()> {
3706    authorize_access_to_legacy_llm_endpoints(&session).await?;
3707
3708    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3709        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3710        proto::Plan::Free | proto::Plan::ZedProTrial => {
3711            Box::new(FreeCountLanguageModelTokensRateLimit)
3712        }
3713    };
3714
3715    session
3716        .app_state
3717        .rate_limiter
3718        .check(&*rate_limit, session.user_id())
3719        .await?;
3720
3721    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3722        Some(proto::LanguageModelProvider::Google) => {
3723            let api_key = config
3724                .google_ai_api_key
3725                .as_ref()
3726                .context("no Google AI API key configured on the server")?;
3727            google_ai::count_tokens(
3728                session.http_client.as_ref(),
3729                google_ai::API_URL,
3730                api_key,
3731                serde_json::from_str(&request.request)?,
3732            )
3733            .await?
3734        }
3735        _ => return Err(anyhow!("unsupported provider"))?,
3736    };
3737
3738    response.send(proto::CountLanguageModelTokensResponse {
3739        token_count: result.total_tokens as u32,
3740    })?;
3741
3742    Ok(())
3743}
3744
3745struct ZedProCountLanguageModelTokensRateLimit;
3746
3747impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3748    fn capacity(&self) -> usize {
3749        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3750            .ok()
3751            .and_then(|v| v.parse().ok())
3752            .unwrap_or(600) // Picked arbitrarily
3753    }
3754
3755    fn refill_duration(&self) -> chrono::Duration {
3756        chrono::Duration::hours(1)
3757    }
3758
3759    fn db_name(&self) -> &'static str {
3760        "zed-pro:count-language-model-tokens"
3761    }
3762}
3763
3764struct FreeCountLanguageModelTokensRateLimit;
3765
3766impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3767    fn capacity(&self) -> usize {
3768        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3769            .ok()
3770            .and_then(|v| v.parse().ok())
3771            .unwrap_or(600 / 10) // Picked arbitrarily
3772    }
3773
3774    fn refill_duration(&self) -> chrono::Duration {
3775        chrono::Duration::hours(1)
3776    }
3777
3778    fn db_name(&self) -> &'static str {
3779        "free:count-language-model-tokens"
3780    }
3781}
3782
3783struct ZedProComputeEmbeddingsRateLimit;
3784
3785impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3786    fn capacity(&self) -> usize {
3787        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3788            .ok()
3789            .and_then(|v| v.parse().ok())
3790            .unwrap_or(5000) // Picked arbitrarily
3791    }
3792
3793    fn refill_duration(&self) -> chrono::Duration {
3794        chrono::Duration::hours(1)
3795    }
3796
3797    fn db_name(&self) -> &'static str {
3798        "zed-pro:compute-embeddings"
3799    }
3800}
3801
3802struct FreeComputeEmbeddingsRateLimit;
3803
3804impl RateLimit for FreeComputeEmbeddingsRateLimit {
3805    fn capacity(&self) -> usize {
3806        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3807            .ok()
3808            .and_then(|v| v.parse().ok())
3809            .unwrap_or(5000 / 10) // Picked arbitrarily
3810    }
3811
3812    fn refill_duration(&self) -> chrono::Duration {
3813        chrono::Duration::hours(1)
3814    }
3815
3816    fn db_name(&self) -> &'static str {
3817        "free:compute-embeddings"
3818    }
3819}
3820
3821async fn compute_embeddings(
3822    request: proto::ComputeEmbeddings,
3823    response: Response<proto::ComputeEmbeddings>,
3824    session: Session,
3825    api_key: Option<Arc<str>>,
3826) -> Result<()> {
3827    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3828    authorize_access_to_legacy_llm_endpoints(&session).await?;
3829
3830    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3831        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3832        proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
3833    };
3834
3835    session
3836        .app_state
3837        .rate_limiter
3838        .check(&*rate_limit, session.user_id())
3839        .await?;
3840
3841    let embeddings = match request.model.as_str() {
3842        "openai/text-embedding-3-small" => {
3843            open_ai::embed(
3844                session.http_client.as_ref(),
3845                OPEN_AI_API_URL,
3846                &api_key,
3847                OpenAiEmbeddingModel::TextEmbedding3Small,
3848                request.texts.iter().map(|text| text.as_str()),
3849            )
3850            .await?
3851        }
3852        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3853    };
3854
3855    let embeddings = request
3856        .texts
3857        .iter()
3858        .map(|text| {
3859            let mut hasher = sha2::Sha256::new();
3860            hasher.update(text.as_bytes());
3861            let result = hasher.finalize();
3862            result.to_vec()
3863        })
3864        .zip(
3865            embeddings
3866                .data
3867                .into_iter()
3868                .map(|embedding| embedding.embedding),
3869        )
3870        .collect::<HashMap<_, _>>();
3871
3872    let db = session.db().await;
3873    db.save_embeddings(&request.model, &embeddings)
3874        .await
3875        .context("failed to save embeddings")
3876        .trace_err();
3877
3878    response.send(proto::ComputeEmbeddingsResponse {
3879        embeddings: embeddings
3880            .into_iter()
3881            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3882            .collect(),
3883    })?;
3884    Ok(())
3885}
3886
3887async fn get_cached_embeddings(
3888    request: proto::GetCachedEmbeddings,
3889    response: Response<proto::GetCachedEmbeddings>,
3890    session: Session,
3891) -> Result<()> {
3892    authorize_access_to_legacy_llm_endpoints(&session).await?;
3893
3894    let db = session.db().await;
3895    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3896
3897    response.send(proto::GetCachedEmbeddingsResponse {
3898        embeddings: embeddings
3899            .into_iter()
3900            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3901            .collect(),
3902    })?;
3903    Ok(())
3904}
3905
3906/// This is leftover from before the LLM service.
3907///
3908/// The endpoints protected by this check will be moved there eventually.
3909async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3910    if session.is_staff() {
3911        Ok(())
3912    } else {
3913        Err(anyhow!("permission denied"))?
3914    }
3915}
3916
3917/// Get a Supermaven API key for the user
3918async fn get_supermaven_api_key(
3919    _request: proto::GetSupermavenApiKey,
3920    response: Response<proto::GetSupermavenApiKey>,
3921    session: Session,
3922) -> Result<()> {
3923    let user_id: String = session.user_id().to_string();
3924    if !session.is_staff() {
3925        return Err(anyhow!("supermaven not enabled for this account"))?;
3926    }
3927
3928    let email = session
3929        .email()
3930        .ok_or_else(|| anyhow!("user must have an email"))?;
3931
3932    let supermaven_admin_api = session
3933        .supermaven_client
3934        .as_ref()
3935        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3936
3937    let result = supermaven_admin_api
3938        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3939        .await?;
3940
3941    response.send(proto::GetSupermavenApiKeyResponse {
3942        api_key: result.api_key,
3943    })?;
3944
3945    Ok(())
3946}
3947
3948/// Start receiving chat updates for a channel
3949async fn join_channel_chat(
3950    request: proto::JoinChannelChat,
3951    response: Response<proto::JoinChannelChat>,
3952    session: Session,
3953) -> Result<()> {
3954    let channel_id = ChannelId::from_proto(request.channel_id);
3955
3956    let db = session.db().await;
3957    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3958        .await?;
3959    let messages = db
3960        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3961        .await?;
3962    response.send(proto::JoinChannelChatResponse {
3963        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3964        messages,
3965    })?;
3966    Ok(())
3967}
3968
3969/// Stop receiving chat updates for a channel
3970async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3971    let channel_id = ChannelId::from_proto(request.channel_id);
3972    session
3973        .db()
3974        .await
3975        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3976        .await?;
3977    Ok(())
3978}
3979
3980/// Retrieve the chat history for a channel
3981async fn get_channel_messages(
3982    request: proto::GetChannelMessages,
3983    response: Response<proto::GetChannelMessages>,
3984    session: Session,
3985) -> Result<()> {
3986    let channel_id = ChannelId::from_proto(request.channel_id);
3987    let messages = session
3988        .db()
3989        .await
3990        .get_channel_messages(
3991            channel_id,
3992            session.user_id(),
3993            MESSAGE_COUNT_PER_PAGE,
3994            Some(MessageId::from_proto(request.before_message_id)),
3995        )
3996        .await?;
3997    response.send(proto::GetChannelMessagesResponse {
3998        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3999        messages,
4000    })?;
4001    Ok(())
4002}
4003
4004/// Retrieve specific chat messages
4005async fn get_channel_messages_by_id(
4006    request: proto::GetChannelMessagesById,
4007    response: Response<proto::GetChannelMessagesById>,
4008    session: Session,
4009) -> Result<()> {
4010    let message_ids = request
4011        .message_ids
4012        .iter()
4013        .map(|id| MessageId::from_proto(*id))
4014        .collect::<Vec<_>>();
4015    let messages = session
4016        .db()
4017        .await
4018        .get_channel_messages_by_id(session.user_id(), &message_ids)
4019        .await?;
4020    response.send(proto::GetChannelMessagesResponse {
4021        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4022        messages,
4023    })?;
4024    Ok(())
4025}
4026
4027/// Retrieve the current users notifications
4028async fn get_notifications(
4029    request: proto::GetNotifications,
4030    response: Response<proto::GetNotifications>,
4031    session: Session,
4032) -> Result<()> {
4033    let notifications = session
4034        .db()
4035        .await
4036        .get_notifications(
4037            session.user_id(),
4038            NOTIFICATION_COUNT_PER_PAGE,
4039            request.before_id.map(db::NotificationId::from_proto),
4040        )
4041        .await?;
4042    response.send(proto::GetNotificationsResponse {
4043        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4044        notifications,
4045    })?;
4046    Ok(())
4047}
4048
4049/// Mark notifications as read
4050async fn mark_notification_as_read(
4051    request: proto::MarkNotificationRead,
4052    response: Response<proto::MarkNotificationRead>,
4053    session: Session,
4054) -> Result<()> {
4055    let database = &session.db().await;
4056    let notifications = database
4057        .mark_notification_as_read_by_id(
4058            session.user_id(),
4059            NotificationId::from_proto(request.notification_id),
4060        )
4061        .await?;
4062    send_notifications(
4063        &*session.connection_pool().await,
4064        &session.peer,
4065        notifications,
4066    );
4067    response.send(proto::Ack {})?;
4068    Ok(())
4069}
4070
4071/// Get the current users information
4072async fn get_private_user_info(
4073    _request: proto::GetPrivateUserInfo,
4074    response: Response<proto::GetPrivateUserInfo>,
4075    session: Session,
4076) -> Result<()> {
4077    let db = session.db().await;
4078
4079    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4080    let user = db
4081        .get_user_by_id(session.user_id())
4082        .await?
4083        .ok_or_else(|| anyhow!("user not found"))?;
4084    let flags = db.get_user_flags(session.user_id()).await?;
4085
4086    response.send(proto::GetPrivateUserInfoResponse {
4087        metrics_id,
4088        staff: user.admin,
4089        flags,
4090        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4091    })?;
4092    Ok(())
4093}
4094
4095/// Accept the terms of service (tos) on behalf of the current user
4096async fn accept_terms_of_service(
4097    _request: proto::AcceptTermsOfService,
4098    response: Response<proto::AcceptTermsOfService>,
4099    session: Session,
4100) -> Result<()> {
4101    let db = session.db().await;
4102
4103    let accepted_tos_at = Utc::now();
4104    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4105        .await?;
4106
4107    response.send(proto::AcceptTermsOfServiceResponse {
4108        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4109    })?;
4110    Ok(())
4111}
4112
/// The minimum age (30 days) an account must have in order to use the LLM
/// service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4115
4116async fn get_llm_api_token(
4117    _request: proto::GetLlmToken,
4118    response: Response<proto::GetLlmToken>,
4119    session: Session,
4120) -> Result<()> {
4121    let db = session.db().await;
4122
4123    let flags = db.get_user_flags(session.user_id()).await?;
4124    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4125
4126    if !session.is_staff() && !has_language_models_feature_flag {
4127        Err(anyhow!("permission denied"))?
4128    }
4129
4130    let user_id = session.user_id();
4131    let user = db
4132        .get_user_by_id(user_id)
4133        .await?
4134        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4135
4136    if user.accepted_tos_at.is_none() {
4137        Err(anyhow!("terms of service not accepted"))?
4138    }
4139
4140    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4141    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4142    let billing_preferences = db.get_billing_preferences(user.id).await?;
4143
4144    let token = LlmTokenClaims::create(
4145        &user,
4146        session.is_staff(),
4147        billing_preferences,
4148        &flags,
4149        has_legacy_llm_subscription,
4150        billing_subscription,
4151        session.system_id.clone(),
4152        &session.app_state.config,
4153    )?;
4154    response.send(proto::GetLlmTokenResponse { token })?;
4155    Ok(())
4156}
4157
4158fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4159    let message = match message {
4160        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4161        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4162        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4163        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4164        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4165            code: frame.code.into(),
4166            reason: frame.reason.as_str().to_owned().into(),
4167        })),
4168        // We should never receive a frame while reading the message, according
4169        // to the `tungstenite` maintainers:
4170        //
4171        // > It cannot occur when you read messages from the WebSocket, but it
4172        // > can be used when you want to send the raw frames (e.g. you want to
4173        // > send the frames to the WebSocket without composing the full message first).
4174        // >
4175        // > — https://github.com/snapview/tungstenite-rs/issues/268
4176        TungsteniteMessage::Frame(_) => {
4177            bail!("received an unexpected frame while reading the message")
4178        }
4179    };
4180
4181    Ok(message)
4182}
4183
4184fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4185    match message {
4186        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4187        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4188        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4189        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4190        AxumMessage::Close(frame) => {
4191            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4192                code: frame.code.into(),
4193                reason: frame.reason.as_ref().into(),
4194            }))
4195        }
4196    }
4197}
4198
4199fn notify_membership_updated(
4200    connection_pool: &mut ConnectionPool,
4201    result: MembershipUpdated,
4202    user_id: UserId,
4203    peer: &Peer,
4204) {
4205    for membership in &result.new_channels.channel_memberships {
4206        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4207    }
4208    for channel_id in &result.removed_channels {
4209        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4210    }
4211
4212    let user_channels_update = proto::UpdateUserChannels {
4213        channel_memberships: result
4214            .new_channels
4215            .channel_memberships
4216            .iter()
4217            .map(|cm| proto::ChannelMembership {
4218                channel_id: cm.channel_id.to_proto(),
4219                role: cm.role.into(),
4220            })
4221            .collect(),
4222        ..Default::default()
4223    };
4224
4225    let mut update = build_channels_update(result.new_channels);
4226    update.delete_channels = result
4227        .removed_channels
4228        .into_iter()
4229        .map(|id| id.to_proto())
4230        .collect();
4231    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4232
4233    for connection_id in connection_pool.user_connection_ids(user_id) {
4234        peer.send(connection_id, user_channels_update.clone())
4235            .trace_err();
4236        peer.send(connection_id, update.clone()).trace_err();
4237    }
4238}
4239
4240fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4241    proto::UpdateUserChannels {
4242        channel_memberships: channels
4243            .channel_memberships
4244            .iter()
4245            .map(|m| proto::ChannelMembership {
4246                channel_id: m.channel_id.to_proto(),
4247                role: m.role.into(),
4248            })
4249            .collect(),
4250        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4251        observed_channel_message_id: channels.observed_channel_messages.clone(),
4252    }
4253}
4254
4255fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4256    let mut update = proto::UpdateChannels::default();
4257
4258    for channel in channels.channels {
4259        update.channels.push(channel.to_proto());
4260    }
4261
4262    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4263    update.latest_channel_message_ids = channels.latest_channel_messages;
4264
4265    for (channel_id, participants) in channels.channel_participants {
4266        update
4267            .channel_participants
4268            .push(proto::ChannelParticipants {
4269                channel_id: channel_id.to_proto(),
4270                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4271            });
4272    }
4273
4274    for channel in channels.invited_channels {
4275        update.channel_invitations.push(channel.to_proto());
4276    }
4277
4278    update
4279}
4280
/// Builds the first `UpdateContacts` message sent to a user, covering their whole contact list.
///
/// Each contact is placed according to its kind:
/// - accepted contacts become full `Contact` entries, with online state read from `pool`;
/// - outgoing requests are reported as the target user ids;
/// - incoming requests become `IncomingContactRequest` entries.
fn build_initial_contacts_update(
    contacts: Vec<db::Contact>,
    pool: &ConnectionPool,
) -> proto::UpdateContacts {
    contacts
        .into_iter()
        .fold(proto::UpdateContacts::default(), |mut update, contact| {
            match contact {
                db::Contact::Accepted { user_id, busy } => {
                    update.contacts.push(contact_for_user(user_id, busy, pool))
                }
                db::Contact::Outgoing { user_id } => {
                    update.outgoing_requests.push(user_id.to_proto())
                }
                db::Contact::Incoming { user_id } => {
                    update.incoming_requests.push(proto::IncomingContactRequest {
                        requester_id: user_id.to_proto(),
                    })
                }
            }
            update
        })
}
4305
4306fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4307    proto::Contact {
4308        user_id: user_id.to_proto(),
4309        online: pool.is_user_online(user_id),
4310        busy,
4311    }
4312}
4313
4314fn room_updated(room: &proto::Room, peer: &Peer) {
4315    broadcast(
4316        None,
4317        room.participants
4318            .iter()
4319            .filter_map(|participant| Some(participant.peer_id?.into())),
4320        |peer_id| {
4321            peer.send(
4322                peer_id,
4323                proto::RoomUpdated {
4324                    room: Some(room.clone()),
4325                },
4326            )
4327        },
4328    );
4329}
4330
4331fn channel_updated(
4332    channel: &db::channel::Model,
4333    room: &proto::Room,
4334    peer: &Peer,
4335    pool: &ConnectionPool,
4336) {
4337    let participants = room
4338        .participants
4339        .iter()
4340        .map(|p| p.user_id)
4341        .collect::<Vec<_>>();
4342
4343    broadcast(
4344        None,
4345        pool.channel_connection_ids(channel.root_id())
4346            .filter_map(|(channel_id, role)| {
4347                role.can_see_channel(channel.visibility)
4348                    .then_some(channel_id)
4349            }),
4350        |peer_id| {
4351            peer.send(
4352                peer_id,
4353                proto::UpdateChannels {
4354                    channel_participants: vec![proto::ChannelParticipants {
4355                        channel_id: channel.id.to_proto(),
4356                        participant_user_ids: participants.clone(),
4357                    }],
4358                    ..Default::default()
4359                },
4360            )
4361        },
4362    );
4363}
4364
4365async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4366    let db = session.db().await;
4367
4368    let contacts = db.get_contacts(user_id).await?;
4369    let busy = db.is_user_busy(user_id).await?;
4370
4371    let pool = session.connection_pool().await;
4372    let updated_contact = contact_for_user(user_id, busy, &pool);
4373    for contact in contacts {
4374        if let db::Contact::Accepted {
4375            user_id: contact_user_id,
4376            ..
4377        } = contact
4378        {
4379            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4380                session
4381                    .peer
4382                    .send(
4383                        contact_conn_id,
4384                        proto::UpdateContacts {
4385                            contacts: vec![updated_contact.clone()],
4386                            remove_contacts: Default::default(),
4387                            incoming_requests: Default::default(),
4388                            remove_incoming_requests: Default::default(),
4389                            outgoing_requests: Default::default(),
4390                            remove_outgoing_requests: Default::default(),
4391                        },
4392                    )
4393                    .trace_err();
4394            }
4395        }
4396    }
4397    Ok(())
4398}
4399
4400async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4401    let mut contacts_to_update = HashSet::default();
4402
4403    let room_id;
4404    let canceled_calls_to_user_ids;
4405    let livekit_room;
4406    let delete_livekit_room;
4407    let room;
4408    let channel;
4409
4410    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4411        contacts_to_update.insert(session.user_id());
4412
4413        for project in left_room.left_projects.values() {
4414            project_left(project, session);
4415        }
4416
4417        room_id = RoomId::from_proto(left_room.room.id);
4418        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4419        livekit_room = mem::take(&mut left_room.room.livekit_room);
4420        delete_livekit_room = left_room.deleted;
4421        room = mem::take(&mut left_room.room);
4422        channel = mem::take(&mut left_room.channel);
4423
4424        room_updated(&room, &session.peer);
4425    } else {
4426        return Ok(());
4427    }
4428
4429    if let Some(channel) = channel {
4430        channel_updated(
4431            &channel,
4432            &room,
4433            &session.peer,
4434            &*session.connection_pool().await,
4435        );
4436    }
4437
4438    {
4439        let pool = session.connection_pool().await;
4440        for canceled_user_id in canceled_calls_to_user_ids {
4441            for connection_id in pool.user_connection_ids(canceled_user_id) {
4442                session
4443                    .peer
4444                    .send(
4445                        connection_id,
4446                        proto::CallCanceled {
4447                            room_id: room_id.to_proto(),
4448                        },
4449                    )
4450                    .trace_err();
4451            }
4452            contacts_to_update.insert(canceled_user_id);
4453        }
4454    }
4455
4456    for contact_user_id in contacts_to_update {
4457        update_user_contacts(contact_user_id, session).await?;
4458    }
4459
4460    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4461        live_kit
4462            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4463            .await
4464            .trace_err();
4465
4466        if delete_livekit_room {
4467            live_kit.delete_room(livekit_room).await.trace_err();
4468        }
4469    }
4470
4471    Ok(())
4472}
4473
4474async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4475    let left_channel_buffers = session
4476        .db()
4477        .await
4478        .leave_channel_buffers(session.connection_id)
4479        .await?;
4480
4481    for left_buffer in left_channel_buffers {
4482        channel_buffer_updated(
4483            session.connection_id,
4484            left_buffer.connections,
4485            &proto::UpdateChannelBufferCollaborators {
4486                channel_id: left_buffer.channel_id.to_proto(),
4487                collaborators: left_buffer.collaborators,
4488            },
4489            &session.peer,
4490        );
4491    }
4492
4493    Ok(())
4494}
4495
4496fn project_left(project: &db::LeftProject, session: &Session) {
4497    for connection_id in &project.connection_ids {
4498        if project.should_unshare {
4499            session
4500                .peer
4501                .send(
4502                    *connection_id,
4503                    proto::UnshareProject {
4504                        project_id: project.id.to_proto(),
4505                    },
4506                )
4507                .trace_err();
4508        } else {
4509            session
4510                .peer
4511                .send(
4512                    *connection_id,
4513                    proto::RemoveProjectCollaborator {
4514                        project_id: project.id.to_proto(),
4515                        peer_id: Some(session.connection_id.into()),
4516                    },
4517                )
4518                .trace_err();
4519        }
4520    }
4521}
4522
/// Extension for fallible values whose error should be reported but not propagated.
pub trait ResultExt {
    /// The success type yielded when the value is not an error.
    type Ok;

    /// Returns the success value as `Some`, or `None` if the value is an error.
    /// The `Result` implementation logs the error with `tracing` before discarding it.
    fn trace_err(self) -> Option<Self::Ok>;
}
4528
4529impl<T, E> ResultExt for Result<T, E>
4530where
4531    E: std::fmt::Debug,
4532{
4533    type Ok = T;
4534
4535    #[track_caller]
4536    fn trace_err(self) -> Option<T> {
4537        match self {
4538            Ok(value) => Some(value),
4539            Err(error) => {
4540                tracing::error!("{:?}", error);
4541                None
4542            }
4543        }
4544    }
4545}