rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Config, Error, RateLimit, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use http_client::HttpClient;
  37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
  38use reqwest_client::ReqwestClient;
  39use rpc::proto::split_repository_update;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  45    stream::FuturesUnordered,
  46};
  47use prometheus::{IntGauge, register_int_gauge};
  48use rpc::{
  49    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        Arc, OnceLock,
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{MutexGuard, Semaphore, watch};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    Instrument,
  76    field::{self},
  77    info_span, instrument,
  78};
  79
/// Reconnection grace period (30s).
///
/// NOTE(review): this constant is not used within this chunk — presumably the
/// window a dropped client has to reconnect before its state is released;
/// confirm against its uses.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the cleanup task spawned by `Server::start` looks up and
/// clears resources still attributed to stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are
/// gone we can clean up their old resources, so this is set a little higher.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries.
/// NOTE(review): presumably consumed by the channel-message handlers; confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length.
/// NOTE(review): whether this counts bytes or chars depends on the caller,
/// which is outside this chunk — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries.
/// NOTE(review): presumably consumed by `get_notifications`; confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the session it arrived on and
/// returns a boxed `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
/// Per-connection state handed to every message handler.
///
/// Handlers take a `Session` by value (see `MessageHandler`). The shared pieces
/// (database, peer, connection pool, app state, clients) sit behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    /// This connection's id, as known to `peer` and the connection pool.
    connection_id: ConnectionId,
    /// Database handle; access it through [`Session::db`], which takes the lock.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool; access it through [`Session::connection_pool`].
    /// NOTE(review): presumably the same pool as `Server::connection_pool` — confirm.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in this chunk; `allow(unused)` keeps it
    // without a dead-code warning — confirm whether it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
/// The RPC server. It owns the peer, the connection pool, and the table of
/// message handlers registered in [`Server::new`].
pub struct Server {
    /// Current server id. It sits behind a mutex so the test-only `reset` can swap it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: set to `true` by `teardown` and back to `false` by the
    /// test-only `reset`.
    teardown: watch::Sender<bool>,
}
 245
/// Lock guard over the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that must
/// be `Send` (as handler futures are) therefore cannot hold it across an
/// `.await`, which keeps this synchronous lock from being held while suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 250
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a pool guard, so the pool stays locked for as long as the
/// snapshot lives. The pool is serialized through that guard via
/// [`serialize_deref`].
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
 268    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 269        let mut server = Self {
 270            id: parking_lot::Mutex::new(id),
 271            peer: Peer::new(id.0 as u32),
 272            app_state: app_state.clone(),
 273            connection_pool: Default::default(),
 274            handlers: Default::default(),
 275            teardown: watch::channel(false).0,
 276        };
 277
 278        server
 279            .add_request_handler(ping)
 280            .add_request_handler(create_room)
 281            .add_request_handler(join_room)
 282            .add_request_handler(rejoin_room)
 283            .add_request_handler(leave_room)
 284            .add_request_handler(set_room_participant_role)
 285            .add_request_handler(call)
 286            .add_request_handler(cancel_call)
 287            .add_message_handler(decline_call)
 288            .add_request_handler(update_participant_location)
 289            .add_request_handler(share_project)
 290            .add_message_handler(unshare_project)
 291            .add_request_handler(join_project)
 292            .add_message_handler(leave_project)
 293            .add_request_handler(update_project)
 294            .add_request_handler(update_worktree)
 295            .add_request_handler(update_repository)
 296            .add_request_handler(remove_repository)
 297            .add_message_handler(start_language_server)
 298            .add_message_handler(update_language_server)
 299            .add_message_handler(update_diagnostic_summary)
 300            .add_message_handler(update_worktree_settings)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 305            .add_request_handler(forward_find_search_candidates_request)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 308            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 309            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 311            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 312            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 313            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 314            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 320            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(
 325                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 326            )
 327            .add_request_handler(
 328                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 329            )
 330            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 331            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 332            .add_request_handler(
 333                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 334            )
 335            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 336            .add_request_handler(
 337                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 338            )
 339            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 358            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 359            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 361            .add_message_handler(create_buffer_for_peer)
 362            .add_request_handler(update_buffer)
 363            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 368            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 369            .add_request_handler(get_users)
 370            .add_request_handler(fuzzy_search_users)
 371            .add_request_handler(request_contact)
 372            .add_request_handler(remove_contact)
 373            .add_request_handler(respond_to_contact_request)
 374            .add_message_handler(subscribe_to_channels)
 375            .add_request_handler(create_channel)
 376            .add_request_handler(delete_channel)
 377            .add_request_handler(invite_channel_member)
 378            .add_request_handler(remove_channel_member)
 379            .add_request_handler(set_channel_member_role)
 380            .add_request_handler(set_channel_visibility)
 381            .add_request_handler(rename_channel)
 382            .add_request_handler(join_channel_buffer)
 383            .add_request_handler(leave_channel_buffer)
 384            .add_message_handler(update_channel_buffer)
 385            .add_request_handler(rejoin_channel_buffers)
 386            .add_request_handler(get_channel_members)
 387            .add_request_handler(respond_to_channel_invite)
 388            .add_request_handler(join_channel)
 389            .add_request_handler(join_channel_chat)
 390            .add_message_handler(leave_channel_chat)
 391            .add_request_handler(send_channel_message)
 392            .add_request_handler(remove_channel_message)
 393            .add_request_handler(update_channel_message)
 394            .add_request_handler(get_channel_messages)
 395            .add_request_handler(get_channel_messages_by_id)
 396            .add_request_handler(get_notifications)
 397            .add_request_handler(mark_notification_as_read)
 398            .add_request_handler(move_channel)
 399            .add_request_handler(follow)
 400            .add_message_handler(unfollow)
 401            .add_message_handler(update_followers)
 402            .add_request_handler(get_private_user_info)
 403            .add_request_handler(get_llm_api_token)
 404            .add_request_handler(accept_terms_of_service)
 405            .add_message_handler(acknowledge_channel_message)
 406            .add_message_handler(acknowledge_buffer_version)
 407            .add_request_handler(get_supermaven_api_key)
 408            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 409            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 410            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 411            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 413            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 416            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 417            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 418            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 419            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 420            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 421            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 422            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 423            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 424            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 425            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 426            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 427            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 429            .add_message_handler(update_context)
 430            .add_request_handler({
 431                let app_state = app_state.clone();
 432                move |request, response, session| {
 433                    let app_state = app_state.clone();
 434                    async move {
 435                        count_language_model_tokens(request, response, session, &app_state.config)
 436                            .await
 437                    }
 438                }
 439            })
 440            .add_request_handler(get_cached_embeddings)
 441            .add_request_handler({
 442                let app_state = app_state.clone();
 443                move |request, response, session| {
 444                    compute_embeddings(
 445                        request,
 446                        response,
 447                        session,
 448                        app_state.config.openai_api_key.clone(),
 449                    )
 450                }
 451            });
 452
 453        Arc::new(server)
 454    }
 455
    /// Spawns a detached background task that cleans up resources left behind
    /// by stale servers, then returns `Ok(())` immediately.
    ///
    /// The task waits [`CLEANUP_TIMEOUT`], then asks the database
    /// (`stale_server_resource_ids`, for this environment and server id) for
    /// stale room and channel-buffer ids. It then:
    /// - clears stale collaborators from each channel buffer and sends the
    ///   refreshed collaborator list to that buffer's connections;
    /// - clears stale participants from each room and broadcasts the updated
    ///   room (and its channel, if any). It sends `CallCanceled` to every
    ///   connection of the users in `canceled_calls_to_user_ids`, and pushes
    ///   each affected user's refreshed contact entry to their accepted
    ///   contacts. If a LiveKit client is configured and no participants
    ///   remain, it deletes the LiveKit room;
    /// - finally calls `delete_stale_servers`, even if the stale-resource
    ///   lookup failed.
    ///
    /// Each fallible step goes through `trace_err`. A failing step yields
    /// `None` and is skipped rather than aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send
                    // the refreshed collaborator list to the buffer's connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if clearing the room fails:
                        // nothing to notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this scope, so it is
                        // released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to
                        // every connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 601
 602    pub fn teardown(&self) {
 603        self.peer.teardown();
 604        self.connection_pool.lock().reset();
 605        let _ = self.teardown.send(true);
 606    }
 607
 608    #[cfg(test)]
 609    pub fn reset(&self, id: ServerId) {
 610        self.teardown();
 611        *self.id.lock() = id;
 612        self.peer.reset(id.0 as u32);
 613        let _ = self.teardown.send(false);
 614    }
 615
 616    #[cfg(test)]
 617    pub fn id(&self) -> ServerId {
 618        *self.id.lock()
 619    }
 620
 621    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 622    where
 623        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 624        Fut: 'static + Send + Future<Output = Result<()>>,
 625        M: EnvelopedMessage,
 626    {
 627        let prev_handler = self.handlers.insert(
 628            TypeId::of::<M>(),
 629            Box::new(move |envelope, session| {
 630                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 631                let received_at = envelope.received_at;
 632                tracing::info!("message received");
 633                let start_time = Instant::now();
 634                let future = (handler)(*envelope, session);
 635                async move {
 636                    let result = future.await;
 637                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 638                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 639                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 640                    let payload_type = M::NAME;
 641
 642                    match result {
 643                        Err(error) => {
 644                            tracing::error!(
 645                                ?error,
 646                                total_duration_ms,
 647                                processing_duration_ms,
 648                                queue_duration_ms,
 649                                payload_type,
 650                                "error handling message"
 651                            )
 652                        }
 653                        Ok(()) => tracing::info!(
 654                            total_duration_ms,
 655                            processing_duration_ms,
 656                            queue_duration_ms,
 657                            "finished handling message"
 658                        ),
 659                    }
 660                }
 661                .boxed()
 662            }),
 663        );
 664        if prev_handler.is_some() {
 665            panic!("registered a handler for the same message twice");
 666        }
 667        self
 668    }
 669
 670    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 671    where
 672        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 673        Fut: 'static + Send + Future<Output = Result<()>>,
 674        M: EnvelopedMessage,
 675    {
 676        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 677        self
 678    }
 679
 680    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 681    where
 682        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 683        Fut: Send + Future<Output = Result<()>>,
 684        M: RequestMessage,
 685    {
 686        let handler = Arc::new(handler);
 687        self.add_handler(move |envelope, session| {
 688            let receipt = envelope.receipt();
 689            let handler = handler.clone();
 690            async move {
 691                let peer = session.peer.clone();
 692                let responded = Arc::new(AtomicBool::default());
 693                let response = Response {
 694                    peer: peer.clone(),
 695                    responded: responded.clone(),
 696                    receipt,
 697                };
 698                match (handler)(envelope.payload, response, session).await {
 699                    Ok(()) => {
 700                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 701                            Ok(())
 702                        } else {
 703                            Err(anyhow!("handler did not send a response"))?
 704                        }
 705                    }
 706                    Err(error) => {
 707                        let proto_err = match &error {
 708                            Error::Internal(err) => err.to_proto(),
 709                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 710                        };
 711                        peer.respond_with_error(receipt, proto_err)?;
 712                        Err(error)
 713                    }
 714                }
 715            }
 716        })
 717    }
 718
    /// Builds the future that drives a single client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. bails out immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and starts its I/O future;
    /// 3. builds the per-connection `Session` (HTTP client, optional Supermaven client, DB handle);
    /// 4. sends the initial client update (hello, contacts, channels, incoming call);
    /// 5. dispatches incoming messages to registered handlers until the connection closes,
    ///    I/O fails, or teardown is signaled;
    /// 6. on close or I/O failure (not on teardown), runs `connection_lost` to sign the user out.
    ///
    /// Everything runs inside a "handle connection" tracing span. Its `connection_id` field is
    /// filled in once the peer assigns an id.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: by `principal.update_span` below and by the
        // `connection_id` record inside the future.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Don't accept new connections once teardown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The peer drives timers through the executor's sleep so tests can substitute a
            // deterministic executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this early return and the one after `send_initial_client_update`
            // happen after `add_connection` but skip `connection_lost`, so the peer (and
            // possibly the pool) entry is not cleaned up here. Confirm it is handled elsewhere.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background together) may be in flight. A
            // permit is taken before reading the next message and released only after that
            // message's handler finishes, which applies backpressure to the incoming stream.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Recreated every iteration. If another `select_biased!` arm wins first, this
                // future, and any permit it already acquired, is dropped.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are polled in order: teardown, then connection I/O, then progress on
                // in-flight foreground handlers, then reading a new message.
                futures::select_biased! {
                    // Teardown returns without signing the user out (no `connection_lost`).
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is constructed;
                            // afterwards it is attached to the future via `instrument`.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended: the client closed the connection.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 867
 868    async fn send_initial_client_update(
 869        &self,
 870        connection_id: ConnectionId,
 871        principal: &Principal,
 872        zed_version: ZedVersion,
 873        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 874        session: &Session,
 875    ) -> Result<()> {
 876        self.peer.send(
 877            connection_id,
 878            proto::Hello {
 879                peer_id: Some(connection_id.into()),
 880            },
 881        )?;
 882        tracing::info!("sent hello message");
 883        if let Some(send_connection_id) = send_connection_id.take() {
 884            let _ = send_connection_id.send(connection_id);
 885        }
 886
 887        match principal {
 888            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 889                if !user.connected_once {
 890                    self.peer.send(connection_id, proto::ShowContacts {})?;
 891                    self.app_state
 892                        .db
 893                        .set_user_connected_once(user.id, true)
 894                        .await?;
 895                }
 896
 897                update_user_plan(user.id, session).await?;
 898
 899                let contacts = self.app_state.db.get_contacts(user.id).await?;
 900
 901                {
 902                    let mut pool = self.connection_pool.lock();
 903                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 904                    self.peer.send(
 905                        connection_id,
 906                        build_initial_contacts_update(contacts, &pool),
 907                    )?;
 908                }
 909
 910                if should_auto_subscribe_to_channels(zed_version) {
 911                    subscribe_user_to_channels(user.id, session).await?;
 912                }
 913
 914                if let Some(incoming_call) =
 915                    self.app_state.db.incoming_call_for_user(user.id).await?
 916                {
 917                    self.peer.send(connection_id, incoming_call)?;
 918                }
 919
 920                update_user_contacts(user.id, session).await?;
 921            }
 922        }
 923
 924        Ok(())
 925    }
 926
 927    pub async fn invite_code_redeemed(
 928        self: &Arc<Self>,
 929        inviter_id: UserId,
 930        invitee_id: UserId,
 931    ) -> Result<()> {
 932        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 933            if let Some(code) = &user.invite_code {
 934                let pool = self.connection_pool.lock();
 935                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 936                for connection_id in pool.user_connection_ids(inviter_id) {
 937                    self.peer.send(
 938                        connection_id,
 939                        proto::UpdateContacts {
 940                            contacts: vec![invitee_contact.clone()],
 941                            ..Default::default()
 942                        },
 943                    )?;
 944                    self.peer.send(
 945                        connection_id,
 946                        proto::UpdateInviteInfo {
 947                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 948                            count: user.invite_count as u32,
 949                        },
 950                    )?;
 951                }
 952            }
 953        }
 954        Ok(())
 955    }
 956
 957    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 958        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 959            if let Some(invite_code) = &user.invite_code {
 960                let pool = self.connection_pool.lock();
 961                for connection_id in pool.user_connection_ids(user_id) {
 962                    self.peer.send(
 963                        connection_id,
 964                        proto::UpdateInviteInfo {
 965                            url: format!(
 966                                "{}{}",
 967                                self.app_state.config.invite_link_prefix, invite_code
 968                            ),
 969                            count: user.invite_count as u32,
 970                        },
 971                    )?;
 972                }
 973            }
 974        }
 975        Ok(())
 976    }
 977
 978    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 979        let pool = self.connection_pool.lock();
 980        for connection_id in pool.user_connection_ids(user_id) {
 981            self.peer
 982                .send(connection_id, proto::RefreshLlmToken {})
 983                .trace_err();
 984        }
 985    }
 986
 987    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 988        ServerSnapshot {
 989            connection_pool: ConnectionPoolGuard {
 990                guard: self.connection_pool.lock(),
 991                _not_send: PhantomData,
 992            },
 993            peer: &self.peer,
 994        }
 995    }
 996}
 997
 998impl Deref for ConnectionPoolGuard<'_> {
 999    type Target = ConnectionPool;
1000
1001    fn deref(&self) -> &Self::Target {
1002        &self.guard
1003    }
1004}
1005
1006impl DerefMut for ConnectionPoolGuard<'_> {
1007    fn deref_mut(&mut self) -> &mut Self::Target {
1008        &mut self.guard
1009    }
1010}
1011
/// In test builds, validates the connection pool's invariants when the guard is released.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; in other builds dropping the guard does nothing extra.
        #[cfg(test)]
        self.check_invariants();
    }
}
1018
1019fn broadcast<F>(
1020    sender_id: Option<ConnectionId>,
1021    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1022    mut f: F,
1023) where
1024    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1025{
1026    for receiver_id in receiver_ids {
1027        if Some(receiver_id) != sender_id {
1028            if let Err(error) = f(receiver_id) {
1029                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1030            }
1031        }
1032    }
1033}
1034
1035pub struct ProtocolVersion(u32);
1036
1037impl Header for ProtocolVersion {
1038    fn name() -> &'static HeaderName {
1039        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1040        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1041    }
1042
1043    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1044    where
1045        Self: Sized,
1046        I: Iterator<Item = &'i axum::http::HeaderValue>,
1047    {
1048        let version = values
1049            .next()
1050            .ok_or_else(axum::headers::Error::invalid)?
1051            .to_str()
1052            .map_err(|_| axum::headers::Error::invalid())?
1053            .parse()
1054            .map_err(|_| axum::headers::Error::invalid())?;
1055        Ok(Self(version))
1056    }
1057
1058    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1059        values.extend([self.0.to_string().parse().unwrap()]);
1060    }
1061}
1062
1063pub struct AppVersionHeader(SemanticVersion);
1064impl Header for AppVersionHeader {
1065    fn name() -> &'static HeaderName {
1066        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1067        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1068    }
1069
1070    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1071    where
1072        Self: Sized,
1073        I: Iterator<Item = &'i axum::http::HeaderValue>,
1074    {
1075        let version = values
1076            .next()
1077            .ok_or_else(axum::headers::Error::invalid)?
1078            .to_str()
1079            .map_err(|_| axum::headers::Error::invalid())?
1080            .parse()
1081            .map_err(|_| axum::headers::Error::invalid())?;
1082        Ok(Self(version))
1083    }
1084
1085    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1086        values.extend([self.0.to_string().parse().unwrap()]);
1087    }
1088}
1089
/// Builds the RPC router: the `/rpc` WebSocket endpoint and the `/metrics` endpoint.
///
/// Ordering matters here. `Router::layer` only wraps routes that were added before it is
/// called, so:
/// - the `ServiceBuilder` layer (app-state extension plus `auth::validate_header`) applies to
///   `/rpc` only, leaving `/metrics` outside that middleware;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer above, so not wrapped by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1101
1102pub async fn handle_websocket_request(
1103    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1104    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1105    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1106    Extension(server): Extension<Arc<Server>>,
1107    Extension(principal): Extension<Principal>,
1108    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1109    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1110    ws: WebSocketUpgrade,
1111) -> axum::response::Response {
1112    if protocol_version != rpc::PROTOCOL_VERSION {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "client must be upgraded".to_string(),
1116        )
1117            .into_response();
1118    }
1119
1120    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1121        return (
1122            StatusCode::UPGRADE_REQUIRED,
1123            "no version header found".to_string(),
1124        )
1125            .into_response();
1126    };
1127
1128    if !version.can_collaborate() {
1129        return (
1130            StatusCode::UPGRADE_REQUIRED,
1131            "client must be upgraded".to_string(),
1132        )
1133            .into_response();
1134    }
1135
1136    let socket_address = socket_address.to_string();
1137    ws.on_upgrade(move |socket| {
1138        let socket = socket
1139            .map_ok(to_tungstenite_message)
1140            .err_into()
1141            .with(|message| async move { to_axum_message(message) });
1142        let connection = Connection::new(Box::pin(socket));
1143        async move {
1144            server
1145                .handle_connection(
1146                    connection,
1147                    socket_address,
1148                    principal,
1149                    version,
1150                    country_code_header.map(|header| header.to_string()),
1151                    system_id_header.map(|header| header.to_string()),
1152                    None,
1153                    Executor::Production,
1154                )
1155                .await;
1156        }
1157    })
1158}
1159
1160pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1161    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1162    let connections_metric = CONNECTIONS_METRIC
1163        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1164
1165    let connections = server
1166        .connection_pool
1167        .lock()
1168        .connections()
1169        .filter(|connection| !connection.admin)
1170        .count();
1171    connections_metric.set(connections as _);
1172
1173    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1174    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1175        register_int_gauge!(
1176            "shared_projects",
1177            "number of open projects with one or more guests"
1178        )
1179        .unwrap()
1180    });
1181
1182    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1183    shared_projects_metric.set(shared_projects as _);
1184
1185    let encoder = prometheus::TextEncoder::new();
1186    let metric_families = prometheus::gather();
1187    let encoded_metrics = encoder
1188        .encode_to_string(&metric_families)
1189        .map_err(|err| anyhow!("{}", err))?;
1190    Ok(encoded_metrics)
1191}
1192
1193#[instrument(err, skip(executor))]
1194async fn connection_lost(
1195    session: Session,
1196    mut teardown: watch::Receiver<bool>,
1197    executor: Executor,
1198) -> Result<()> {
1199    session.peer.disconnect(session.connection_id);
1200    session
1201        .connection_pool()
1202        .await
1203        .remove_connection(session.connection_id)?;
1204
1205    session
1206        .db()
1207        .await
1208        .connection_lost(session.connection_id)
1209        .await
1210        .trace_err();
1211
1212    futures::select_biased! {
1213        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1214
1215            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1216            leave_room_for_session(&session, session.connection_id).await.trace_err();
1217            leave_channel_buffers_for_session(&session)
1218                .await
1219                .trace_err();
1220
1221            if !session
1222                .connection_pool()
1223                .await
1224                .is_user_online(session.user_id())
1225            {
1226                let db = session.db().await;
1227                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1228                    room_updated(&room, &session.peer);
1229                }
1230            }
1231
1232            update_user_contacts(session.user_id(), &session).await?;
1233        },
1234        _ = teardown.changed().fuse() => {}
1235    }
1236
1237    Ok(())
1238}
1239
1240/// Acknowledges a ping from a client, used to keep the connection alive.
1241async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1242    response.send(proto::Ack {})?;
1243    Ok(())
1244}
1245
1246/// Creates a new room for calling (outside of channels)
1247async fn create_room(
1248    _request: proto::CreateRoom,
1249    response: Response<proto::CreateRoom>,
1250    session: Session,
1251) -> Result<()> {
1252    let livekit_room = nanoid::nanoid!(30);
1253
1254    let live_kit_connection_info = util::maybe!(async {
1255        let live_kit = session.app_state.livekit_client.as_ref();
1256        let live_kit = live_kit?;
1257        let user_id = session.user_id().to_string();
1258
1259        let token = live_kit
1260            .room_token(&livekit_room, &user_id.to_string())
1261            .trace_err()?;
1262
1263        Some(proto::LiveKitConnectionInfo {
1264            server_url: live_kit.url().into(),
1265            token,
1266            can_publish: true,
1267        })
1268    })
1269    .await;
1270
1271    let room = session
1272        .db()
1273        .await
1274        .create_room(session.user_id(), session.connection_id, &livekit_room)
1275        .await?;
1276
1277    response.send(proto::CreateRoomResponse {
1278        room: Some(room.clone()),
1279        live_kit_connection_info,
1280    })?;
1281
1282    update_user_contacts(session.user_id(), &session).await?;
1283    Ok(())
1284}
1285
1286/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1287async fn join_room(
1288    request: proto::JoinRoom,
1289    response: Response<proto::JoinRoom>,
1290    session: Session,
1291) -> Result<()> {
1292    let room_id = RoomId::from_proto(request.id);
1293
1294    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1295
1296    if let Some(channel_id) = channel_id {
1297        return join_channel_internal(channel_id, Box::new(response), session).await;
1298    }
1299
1300    let joined_room = {
1301        let room = session
1302            .db()
1303            .await
1304            .join_room(room_id, session.user_id(), session.connection_id)
1305            .await?;
1306        room_updated(&room.room, &session.peer);
1307        room.into_inner()
1308    };
1309
1310    for connection_id in session
1311        .connection_pool()
1312        .await
1313        .user_connection_ids(session.user_id())
1314    {
1315        session
1316            .peer
1317            .send(
1318                connection_id,
1319                proto::CallCanceled {
1320                    room_id: room_id.to_proto(),
1321                },
1322            )
1323            .trace_err();
1324    }
1325
1326    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1327    {
1328        live_kit
1329            .room_token(
1330                &joined_room.room.livekit_room,
1331                &session.user_id().to_string(),
1332            )
1333            .trace_err()
1334            .map(|token| proto::LiveKitConnectionInfo {
1335                server_url: live_kit.url().into(),
1336                token,
1337                can_publish: true,
1338            })
1339    } else {
1340        None
1341    };
1342
1343    response.send(proto::JoinRoomResponse {
1344        room: Some(joined_room.room),
1345        channel_id: None,
1346        live_kit_connection_info,
1347    })?;
1348
1349    update_user_contacts(session.user_id(), &session).await?;
1350    Ok(())
1351}
1352
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. The database moves the user's room membership to the new connection.
/// 2. The client receives the room plus its reshared and rejoined projects.
/// 3. For each reshared project, collaborators are told the host's peer id changed, and the
///    project's worktrees are forwarded to the other collaborators.
/// 4. Rejoined projects are streamed back to this connection by `notify_rejoined_projects`.
/// 5. After the database result has been released, channel participants (if the room belongs
///    to a channel) and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled inside the inner scope so `rejoined_room` can be dropped before the awaits below.
    let room;
    let channel;
    {
        // NOTE(review): `rejoined_room` exposes `into_inner()`, so it appears to be a guard
        // around the DB result. Keeping it in this scope limits how long it is held.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Swap the old connection's peer id for the new one on each collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes `&mut` because it drains per-project data (e.g. worktrees) while streaming it.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1444
1445fn notify_rejoined_projects(
1446    rejoined_projects: &mut Vec<RejoinedProject>,
1447    session: &Session,
1448) -> Result<()> {
1449    for project in rejoined_projects.iter() {
1450        for collaborator in &project.collaborators {
1451            session
1452                .peer
1453                .send(
1454                    collaborator.connection_id,
1455                    proto::UpdateProjectCollaborator {
1456                        project_id: project.id.to_proto(),
1457                        old_peer_id: Some(project.old_connection_id.into()),
1458                        new_peer_id: Some(session.connection_id.into()),
1459                    },
1460                )
1461                .trace_err();
1462        }
1463    }
1464
1465    for project in rejoined_projects {
1466        for worktree in mem::take(&mut project.worktrees) {
1467            // Stream this worktree's entries.
1468            let message = proto::UpdateWorktree {
1469                project_id: project.id.to_proto(),
1470                worktree_id: worktree.id,
1471                abs_path: worktree.abs_path.clone(),
1472                root_name: worktree.root_name,
1473                updated_entries: worktree.updated_entries,
1474                removed_entries: worktree.removed_entries,
1475                scan_id: worktree.scan_id,
1476                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1477                updated_repositories: worktree.updated_repositories,
1478                removed_repositories: worktree.removed_repositories,
1479            };
1480            for update in proto::split_worktree_update(message) {
1481                session.peer.send(session.connection_id, update)?;
1482            }
1483
1484            // Stream this worktree's diagnostics.
1485            for summary in worktree.diagnostic_summaries {
1486                session.peer.send(
1487                    session.connection_id,
1488                    proto::UpdateDiagnosticSummary {
1489                        project_id: project.id.to_proto(),
1490                        worktree_id: worktree.id,
1491                        summary: Some(summary),
1492                    },
1493                )?;
1494            }
1495
1496            for settings_file in worktree.settings_files {
1497                session.peer.send(
1498                    session.connection_id,
1499                    proto::UpdateWorktreeSettings {
1500                        project_id: project.id.to_proto(),
1501                        worktree_id: worktree.id,
1502                        path: settings_file.path,
1503                        content: Some(settings_file.content),
1504                        kind: Some(settings_file.kind.to_proto().into()),
1505                    },
1506                )?;
1507            }
1508        }
1509
1510        for repository in mem::take(&mut project.updated_repositories) {
1511            for update in split_repository_update(repository) {
1512                session.peer.send(session.connection_id, update)?;
1513            }
1514        }
1515
1516        for id in mem::take(&mut project.removed_repositories) {
1517            session.peer.send(
1518                session.connection_id,
1519                proto::RemoveRepository {
1520                    project_id: project.id.to_proto(),
1521                    id,
1522                },
1523            )?;
1524        }
1525    }
1526
1527    Ok(())
1528}
1529
1530/// leave room disconnects from the room.
1531async fn leave_room(
1532    _: proto::LeaveRoom,
1533    response: Response<proto::LeaveRoom>,
1534    session: Session,
1535) -> Result<()> {
1536    leave_room_for_session(&session, session.connection_id).await?;
1537    response.send(proto::Ack {})?;
1538    Ok(())
1539}
1540
1541/// Updates the permissions of someone else in the room.
1542async fn set_room_participant_role(
1543    request: proto::SetRoomParticipantRole,
1544    response: Response<proto::SetRoomParticipantRole>,
1545    session: Session,
1546) -> Result<()> {
1547    let user_id = UserId::from_proto(request.user_id);
1548    let role = ChannelRole::from(request.role());
1549
1550    let (livekit_room, can_publish) = {
1551        let room = session
1552            .db()
1553            .await
1554            .set_room_participant_role(
1555                session.user_id(),
1556                RoomId::from_proto(request.room_id),
1557                user_id,
1558                role,
1559            )
1560            .await?;
1561
1562        let livekit_room = room.livekit_room.clone();
1563        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1564        room_updated(&room, &session.peer);
1565        (livekit_room, can_publish)
1566    };
1567
1568    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1569        live_kit
1570            .update_participant(
1571                livekit_room.clone(),
1572                request.user_id.to_string(),
1573                livekit_api::proto::ParticipantPermission {
1574                    can_subscribe: true,
1575                    can_publish,
1576                    can_publish_data: can_publish,
1577                    hidden: false,
1578                    recorder: false,
1579                },
1580            )
1581            .await
1582            .trace_err();
1583    }
1584
1585    response.send(proto::Ack {})?;
1586    Ok(())
1587}
1588
1589/// Call someone else into the current room
1590async fn call(
1591    request: proto::Call,
1592    response: Response<proto::Call>,
1593    session: Session,
1594) -> Result<()> {
1595    let room_id = RoomId::from_proto(request.room_id);
1596    let calling_user_id = session.user_id();
1597    let calling_connection_id = session.connection_id;
1598    let called_user_id = UserId::from_proto(request.called_user_id);
1599    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1600    if !session
1601        .db()
1602        .await
1603        .has_contact(calling_user_id, called_user_id)
1604        .await?
1605    {
1606        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1607    }
1608
1609    let incoming_call = {
1610        let (room, incoming_call) = &mut *session
1611            .db()
1612            .await
1613            .call(
1614                room_id,
1615                calling_user_id,
1616                calling_connection_id,
1617                called_user_id,
1618                initial_project_id,
1619            )
1620            .await?;
1621        room_updated(room, &session.peer);
1622        mem::take(incoming_call)
1623    };
1624    update_user_contacts(called_user_id, &session).await?;
1625
1626    let mut calls = session
1627        .connection_pool()
1628        .await
1629        .user_connection_ids(called_user_id)
1630        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1631        .collect::<FuturesUnordered<_>>();
1632
1633    while let Some(call_response) = calls.next().await {
1634        match call_response.as_ref() {
1635            Ok(_) => {
1636                response.send(proto::Ack {})?;
1637                return Ok(());
1638            }
1639            Err(_) => {
1640                call_response.trace_err();
1641            }
1642        }
1643    }
1644
1645    {
1646        let room = session
1647            .db()
1648            .await
1649            .call_failed(room_id, called_user_id)
1650            .await?;
1651        room_updated(&room, &session.peer);
1652    }
1653    update_user_contacts(called_user_id, &session).await?;
1654
1655    Err(anyhow!("failed to ring user"))?
1656}
1657
1658/// Cancel an outgoing call.
1659async fn cancel_call(
1660    request: proto::CancelCall,
1661    response: Response<proto::CancelCall>,
1662    session: Session,
1663) -> Result<()> {
1664    let called_user_id = UserId::from_proto(request.called_user_id);
1665    let room_id = RoomId::from_proto(request.room_id);
1666    {
1667        let room = session
1668            .db()
1669            .await
1670            .cancel_call(room_id, session.connection_id, called_user_id)
1671            .await?;
1672        room_updated(&room, &session.peer);
1673    }
1674
1675    for connection_id in session
1676        .connection_pool()
1677        .await
1678        .user_connection_ids(called_user_id)
1679    {
1680        session
1681            .peer
1682            .send(
1683                connection_id,
1684                proto::CallCanceled {
1685                    room_id: room_id.to_proto(),
1686                },
1687            )
1688            .trace_err();
1689    }
1690    response.send(proto::Ack {})?;
1691
1692    update_user_contacts(called_user_id, &session).await?;
1693    Ok(())
1694}
1695
1696/// Decline an incoming call.
1697async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1698    let room_id = RoomId::from_proto(message.room_id);
1699    {
1700        let room = session
1701            .db()
1702            .await
1703            .decline_call(Some(room_id), session.user_id())
1704            .await?
1705            .ok_or_else(|| anyhow!("failed to decline call"))?;
1706        room_updated(&room, &session.peer);
1707    }
1708
1709    for connection_id in session
1710        .connection_pool()
1711        .await
1712        .user_connection_ids(session.user_id())
1713    {
1714        session
1715            .peer
1716            .send(
1717                connection_id,
1718                proto::CallCanceled {
1719                    room_id: room_id.to_proto(),
1720                },
1721            )
1722            .trace_err();
1723    }
1724    update_user_contacts(session.user_id(), &session).await?;
1725    Ok(())
1726}
1727
1728/// Updates other participants in the room with your current location.
1729async fn update_participant_location(
1730    request: proto::UpdateParticipantLocation,
1731    response: Response<proto::UpdateParticipantLocation>,
1732    session: Session,
1733) -> Result<()> {
1734    let room_id = RoomId::from_proto(request.room_id);
1735    let location = request
1736        .location
1737        .ok_or_else(|| anyhow!("invalid location"))?;
1738
1739    let db = session.db().await;
1740    let room = db
1741        .update_room_participant_location(room_id, session.connection_id, location)
1742        .await?;
1743
1744    room_updated(&room, &session.peer);
1745    response.send(proto::Ack {})?;
1746    Ok(())
1747}
1748
1749/// Share a project into the room.
1750async fn share_project(
1751    request: proto::ShareProject,
1752    response: Response<proto::ShareProject>,
1753    session: Session,
1754) -> Result<()> {
1755    let (project_id, room) = &*session
1756        .db()
1757        .await
1758        .share_project(
1759            RoomId::from_proto(request.room_id),
1760            session.connection_id,
1761            &request.worktrees,
1762            request.is_ssh_project,
1763        )
1764        .await?;
1765    response.send(proto::ShareProjectResponse {
1766        project_id: project_id.to_proto(),
1767    })?;
1768    room_updated(room, &session.peer);
1769
1770    Ok(())
1771}
1772
1773/// Unshare a project from the room.
1774async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1775    let project_id = ProjectId::from_proto(message.project_id);
1776    unshare_project_internal(project_id, session.connection_id, &session).await
1777}
1778
1779async fn unshare_project_internal(
1780    project_id: ProjectId,
1781    connection_id: ConnectionId,
1782    session: &Session,
1783) -> Result<()> {
1784    let delete = {
1785        let room_guard = session
1786            .db()
1787            .await
1788            .unshare_project(project_id, connection_id)
1789            .await?;
1790
1791        let (delete, room, guest_connection_ids) = &*room_guard;
1792
1793        let message = proto::UnshareProject {
1794            project_id: project_id.to_proto(),
1795        };
1796
1797        broadcast(
1798            Some(connection_id),
1799            guest_connection_ids.iter().copied(),
1800            |conn_id| session.peer.send(conn_id, message.clone()),
1801        );
1802        if let Some(room) = room {
1803            room_updated(room, &session.peer);
1804        }
1805
1806        *delete
1807    };
1808
1809    if delete {
1810        let db = session.db().await;
1811        db.delete_project(project_id).await?;
1812    }
1813
1814    Ok(())
1815}
1816
1817/// Join someone elses shared project.
1818async fn join_project(
1819    request: proto::JoinProject,
1820    response: Response<proto::JoinProject>,
1821    session: Session,
1822) -> Result<()> {
1823    let project_id = ProjectId::from_proto(request.project_id);
1824
1825    tracing::info!(%project_id, "join project");
1826
1827    let db = session.db().await;
1828    let (project, replica_id) = &mut *db
1829        .join_project(project_id, session.connection_id, session.user_id())
1830        .await?;
1831    drop(db);
1832    tracing::info!(%project_id, "join remote project");
1833    join_project_internal(response, session, project, replica_id)
1834}
1835
/// Reply channel used by `join_project_internal`.
///
/// Lets `join_project_internal` send its `JoinProjectResponse` without
/// depending on a concrete `Response` type.
trait JoinProjectInternalResponse {
    /// Sends `result` to the requester, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request directly through its `Response`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path-qualified call that delegates to `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1844
1845fn join_project_internal(
1846    response: impl JoinProjectInternalResponse,
1847    session: Session,
1848    project: &mut Project,
1849    replica_id: &ReplicaId,
1850) -> Result<()> {
1851    let collaborators = project
1852        .collaborators
1853        .iter()
1854        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1855        .map(|collaborator| collaborator.to_proto())
1856        .collect::<Vec<_>>();
1857    let project_id = project.id;
1858    let guest_user_id = session.user_id();
1859
1860    let worktrees = project
1861        .worktrees
1862        .iter()
1863        .map(|(id, worktree)| proto::WorktreeMetadata {
1864            id: *id,
1865            root_name: worktree.root_name.clone(),
1866            visible: worktree.visible,
1867            abs_path: worktree.abs_path.clone(),
1868        })
1869        .collect::<Vec<_>>();
1870
1871    let add_project_collaborator = proto::AddProjectCollaborator {
1872        project_id: project_id.to_proto(),
1873        collaborator: Some(proto::Collaborator {
1874            peer_id: Some(session.connection_id.into()),
1875            replica_id: replica_id.0 as u32,
1876            user_id: guest_user_id.to_proto(),
1877            is_host: false,
1878        }),
1879    };
1880
1881    for collaborator in &collaborators {
1882        session
1883            .peer
1884            .send(
1885                collaborator.peer_id.unwrap().into(),
1886                add_project_collaborator.clone(),
1887            )
1888            .trace_err();
1889    }
1890
1891    // First, we send the metadata associated with each worktree.
1892    response.send(proto::JoinProjectResponse {
1893        project_id: project.id.0 as u64,
1894        worktrees: worktrees.clone(),
1895        replica_id: replica_id.0 as u32,
1896        collaborators: collaborators.clone(),
1897        language_servers: project.language_servers.clone(),
1898        role: project.role.into(),
1899    })?;
1900
1901    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1902        // Stream this worktree's entries.
1903        let message = proto::UpdateWorktree {
1904            project_id: project_id.to_proto(),
1905            worktree_id,
1906            abs_path: worktree.abs_path.clone(),
1907            root_name: worktree.root_name,
1908            updated_entries: worktree.entries,
1909            removed_entries: Default::default(),
1910            scan_id: worktree.scan_id,
1911            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1912            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1913            removed_repositories: Default::default(),
1914        };
1915        for update in proto::split_worktree_update(message) {
1916            session.peer.send(session.connection_id, update.clone())?;
1917        }
1918
1919        // Stream this worktree's diagnostics.
1920        for summary in worktree.diagnostic_summaries {
1921            session.peer.send(
1922                session.connection_id,
1923                proto::UpdateDiagnosticSummary {
1924                    project_id: project_id.to_proto(),
1925                    worktree_id: worktree.id,
1926                    summary: Some(summary),
1927                },
1928            )?;
1929        }
1930
1931        for settings_file in worktree.settings_files {
1932            session.peer.send(
1933                session.connection_id,
1934                proto::UpdateWorktreeSettings {
1935                    project_id: project_id.to_proto(),
1936                    worktree_id: worktree.id,
1937                    path: settings_file.path,
1938                    content: Some(settings_file.content),
1939                    kind: Some(settings_file.kind.to_proto() as i32),
1940                },
1941            )?;
1942        }
1943    }
1944
1945    for repository in mem::take(&mut project.repositories) {
1946        for update in split_repository_update(repository) {
1947            session.peer.send(session.connection_id, update)?;
1948        }
1949    }
1950
1951    for language_server in &project.language_servers {
1952        session.peer.send(
1953            session.connection_id,
1954            proto::UpdateLanguageServer {
1955                project_id: project_id.to_proto(),
1956                language_server_id: language_server.id,
1957                variant: Some(
1958                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1959                        proto::LspDiskBasedDiagnosticsUpdated {},
1960                    ),
1961                ),
1962            },
1963        )?;
1964    }
1965
1966    Ok(())
1967}
1968
1969/// Leave someone elses shared project.
1970async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1971    let sender_id = session.connection_id;
1972    let project_id = ProjectId::from_proto(request.project_id);
1973    let db = session.db().await;
1974
1975    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1976    tracing::info!(
1977        %project_id,
1978        "leave project"
1979    );
1980
1981    project_left(project, &session);
1982    if let Some(room) = room {
1983        room_updated(room, &session.peer);
1984    }
1985
1986    Ok(())
1987}
1988
1989/// Updates other participants with changes to the project
1990async fn update_project(
1991    request: proto::UpdateProject,
1992    response: Response<proto::UpdateProject>,
1993    session: Session,
1994) -> Result<()> {
1995    let project_id = ProjectId::from_proto(request.project_id);
1996    let (room, guest_connection_ids) = &*session
1997        .db()
1998        .await
1999        .update_project(project_id, session.connection_id, &request.worktrees)
2000        .await?;
2001    broadcast(
2002        Some(session.connection_id),
2003        guest_connection_ids.iter().copied(),
2004        |connection_id| {
2005            session
2006                .peer
2007                .forward_send(session.connection_id, connection_id, request.clone())
2008        },
2009    );
2010    if let Some(room) = room {
2011        room_updated(room, &session.peer);
2012    }
2013    response.send(proto::Ack {})?;
2014
2015    Ok(())
2016}
2017
2018/// Updates other participants with changes to the worktree
2019async fn update_worktree(
2020    request: proto::UpdateWorktree,
2021    response: Response<proto::UpdateWorktree>,
2022    session: Session,
2023) -> Result<()> {
2024    let guest_connection_ids = session
2025        .db()
2026        .await
2027        .update_worktree(&request, session.connection_id)
2028        .await?;
2029
2030    broadcast(
2031        Some(session.connection_id),
2032        guest_connection_ids.iter().copied(),
2033        |connection_id| {
2034            session
2035                .peer
2036                .forward_send(session.connection_id, connection_id, request.clone())
2037        },
2038    );
2039    response.send(proto::Ack {})?;
2040    Ok(())
2041}
2042
2043async fn update_repository(
2044    request: proto::UpdateRepository,
2045    response: Response<proto::UpdateRepository>,
2046    session: Session,
2047) -> Result<()> {
2048    let guest_connection_ids = session
2049        .db()
2050        .await
2051        .update_repository(&request, session.connection_id)
2052        .await?;
2053
2054    broadcast(
2055        Some(session.connection_id),
2056        guest_connection_ids.iter().copied(),
2057        |connection_id| {
2058            session
2059                .peer
2060                .forward_send(session.connection_id, connection_id, request.clone())
2061        },
2062    );
2063    response.send(proto::Ack {})?;
2064    Ok(())
2065}
2066
2067async fn remove_repository(
2068    request: proto::RemoveRepository,
2069    response: Response<proto::RemoveRepository>,
2070    session: Session,
2071) -> Result<()> {
2072    let guest_connection_ids = session
2073        .db()
2074        .await
2075        .remove_repository(&request, session.connection_id)
2076        .await?;
2077
2078    broadcast(
2079        Some(session.connection_id),
2080        guest_connection_ids.iter().copied(),
2081        |connection_id| {
2082            session
2083                .peer
2084                .forward_send(session.connection_id, connection_id, request.clone())
2085        },
2086    );
2087    response.send(proto::Ack {})?;
2088    Ok(())
2089}
2090
2091/// Updates other participants with changes to the diagnostics
2092async fn update_diagnostic_summary(
2093    message: proto::UpdateDiagnosticSummary,
2094    session: Session,
2095) -> Result<()> {
2096    let guest_connection_ids = session
2097        .db()
2098        .await
2099        .update_diagnostic_summary(&message, session.connection_id)
2100        .await?;
2101
2102    broadcast(
2103        Some(session.connection_id),
2104        guest_connection_ids.iter().copied(),
2105        |connection_id| {
2106            session
2107                .peer
2108                .forward_send(session.connection_id, connection_id, message.clone())
2109        },
2110    );
2111
2112    Ok(())
2113}
2114
2115/// Updates other participants with changes to the worktree settings
2116async fn update_worktree_settings(
2117    message: proto::UpdateWorktreeSettings,
2118    session: Session,
2119) -> Result<()> {
2120    let guest_connection_ids = session
2121        .db()
2122        .await
2123        .update_worktree_settings(&message, session.connection_id)
2124        .await?;
2125
2126    broadcast(
2127        Some(session.connection_id),
2128        guest_connection_ids.iter().copied(),
2129        |connection_id| {
2130            session
2131                .peer
2132                .forward_send(session.connection_id, connection_id, message.clone())
2133        },
2134    );
2135
2136    Ok(())
2137}
2138
2139/// Notify other participants that a language server has started.
2140async fn start_language_server(
2141    request: proto::StartLanguageServer,
2142    session: Session,
2143) -> Result<()> {
2144    let guest_connection_ids = session
2145        .db()
2146        .await
2147        .start_language_server(&request, session.connection_id)
2148        .await?;
2149
2150    broadcast(
2151        Some(session.connection_id),
2152        guest_connection_ids.iter().copied(),
2153        |connection_id| {
2154            session
2155                .peer
2156                .forward_send(session.connection_id, connection_id, request.clone())
2157        },
2158    );
2159    Ok(())
2160}
2161
2162/// Notify other participants that a language server has changed.
2163async fn update_language_server(
2164    request: proto::UpdateLanguageServer,
2165    session: Session,
2166) -> Result<()> {
2167    let project_id = ProjectId::from_proto(request.project_id);
2168    let project_connection_ids = session
2169        .db()
2170        .await
2171        .project_connection_ids(project_id, session.connection_id, true)
2172        .await?;
2173    broadcast(
2174        Some(session.connection_id),
2175        project_connection_ids.iter().copied(),
2176        |connection_id| {
2177            session
2178                .peer
2179                .forward_send(session.connection_id, connection_id, request.clone())
2180        },
2181    );
2182    Ok(())
2183}
2184
2185/// forward a project request to the host. These requests should be read only
2186/// as guests are allowed to send them.
2187async fn forward_read_only_project_request<T>(
2188    request: T,
2189    response: Response<T>,
2190    session: Session,
2191) -> Result<()>
2192where
2193    T: EntityMessage + RequestMessage,
2194{
2195    let project_id = ProjectId::from_proto(request.remote_entity_id());
2196    let host_connection_id = session
2197        .db()
2198        .await
2199        .host_for_read_only_project_request(project_id, session.connection_id)
2200        .await?;
2201    let payload = session
2202        .peer
2203        .forward_request(session.connection_id, host_connection_id, request)
2204        .await?;
2205    response.send(payload)?;
2206    Ok(())
2207}
2208
2209async fn forward_find_search_candidates_request(
2210    request: proto::FindSearchCandidates,
2211    response: Response<proto::FindSearchCandidates>,
2212    session: Session,
2213) -> Result<()> {
2214    let project_id = ProjectId::from_proto(request.remote_entity_id());
2215    let host_connection_id = session
2216        .db()
2217        .await
2218        .host_for_read_only_project_request(project_id, session.connection_id)
2219        .await?;
2220    let payload = session
2221        .peer
2222        .forward_request(session.connection_id, host_connection_id, request)
2223        .await?;
2224    response.send(payload)?;
2225    Ok(())
2226}
2227
2228/// forward a project request to the host. These requests are disallowed
2229/// for guests.
2230async fn forward_mutating_project_request<T>(
2231    request: T,
2232    response: Response<T>,
2233    session: Session,
2234) -> Result<()>
2235where
2236    T: EntityMessage + RequestMessage,
2237{
2238    let project_id = ProjectId::from_proto(request.remote_entity_id());
2239
2240    let host_connection_id = session
2241        .db()
2242        .await
2243        .host_for_mutating_project_request(project_id, session.connection_id)
2244        .await?;
2245    let payload = session
2246        .peer
2247        .forward_request(session.connection_id, host_connection_id, request)
2248        .await?;
2249    response.send(payload)?;
2250    Ok(())
2251}
2252
2253/// Notify other participants that a new buffer has been created
2254async fn create_buffer_for_peer(
2255    request: proto::CreateBufferForPeer,
2256    session: Session,
2257) -> Result<()> {
2258    session
2259        .db()
2260        .await
2261        .check_user_is_project_host(
2262            ProjectId::from_proto(request.project_id),
2263            session.connection_id,
2264        )
2265        .await?;
2266    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2267    session
2268        .peer
2269        .forward_send(session.connection_id, peer_id.into(), request)?;
2270    Ok(())
2271}
2272
2273/// Notify other participants that a buffer has been updated. This is
2274/// allowed for guests as long as the update is limited to selections.
2275async fn update_buffer(
2276    request: proto::UpdateBuffer,
2277    response: Response<proto::UpdateBuffer>,
2278    session: Session,
2279) -> Result<()> {
2280    let project_id = ProjectId::from_proto(request.project_id);
2281    let mut capability = Capability::ReadOnly;
2282
2283    for op in request.operations.iter() {
2284        match op.variant {
2285            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2286            Some(_) => capability = Capability::ReadWrite,
2287        }
2288    }
2289
2290    let host = {
2291        let guard = session
2292            .db()
2293            .await
2294            .connections_for_buffer_update(project_id, session.connection_id, capability)
2295            .await?;
2296
2297        let (host, guests) = &*guard;
2298
2299        broadcast(
2300            Some(session.connection_id),
2301            guests.clone(),
2302            |connection_id| {
2303                session
2304                    .peer
2305                    .forward_send(session.connection_id, connection_id, request.clone())
2306            },
2307        );
2308
2309        *host
2310    };
2311
2312    if host != session.connection_id {
2313        session
2314            .peer
2315            .forward_request(session.connection_id, host, request.clone())
2316            .await?;
2317    }
2318
2319    response.send(proto::Ack {})?;
2320    Ok(())
2321}
2322
2323async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2324    let project_id = ProjectId::from_proto(message.project_id);
2325
2326    let operation = message.operation.as_ref().context("invalid operation")?;
2327    let capability = match operation.variant.as_ref() {
2328        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2329            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2330                match buffer_op.variant {
2331                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2332                        Capability::ReadOnly
2333                    }
2334                    _ => Capability::ReadWrite,
2335                }
2336            } else {
2337                Capability::ReadWrite
2338            }
2339        }
2340        Some(_) => Capability::ReadWrite,
2341        None => Capability::ReadOnly,
2342    };
2343
2344    let guard = session
2345        .db()
2346        .await
2347        .connections_for_buffer_update(project_id, session.connection_id, capability)
2348        .await?;
2349
2350    let (host, guests) = &*guard;
2351
2352    broadcast(
2353        Some(session.connection_id),
2354        guests.iter().chain([host]).copied(),
2355        |connection_id| {
2356            session
2357                .peer
2358                .forward_send(session.connection_id, connection_id, message.clone())
2359        },
2360    );
2361
2362    Ok(())
2363}
2364
2365/// Notify other participants that a project has been updated.
2366async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2367    request: T,
2368    session: Session,
2369) -> Result<()> {
2370    let project_id = ProjectId::from_proto(request.remote_entity_id());
2371    let project_connection_ids = session
2372        .db()
2373        .await
2374        .project_connection_ids(project_id, session.connection_id, false)
2375        .await?;
2376
2377    broadcast(
2378        Some(session.connection_id),
2379        project_connection_ids.iter().copied(),
2380        |connection_id| {
2381            session
2382                .peer
2383                .forward_send(session.connection_id, connection_id, request.clone())
2384        },
2385    );
2386    Ok(())
2387}
2388
2389/// Start following another user in a call.
2390async fn follow(
2391    request: proto::Follow,
2392    response: Response<proto::Follow>,
2393    session: Session,
2394) -> Result<()> {
2395    let room_id = RoomId::from_proto(request.room_id);
2396    let project_id = request.project_id.map(ProjectId::from_proto);
2397    let leader_id = request
2398        .leader_id
2399        .ok_or_else(|| anyhow!("invalid leader id"))?
2400        .into();
2401    let follower_id = session.connection_id;
2402
2403    session
2404        .db()
2405        .await
2406        .check_room_participants(room_id, leader_id, session.connection_id)
2407        .await?;
2408
2409    let response_payload = session
2410        .peer
2411        .forward_request(session.connection_id, leader_id, request)
2412        .await?;
2413    response.send(response_payload)?;
2414
2415    if let Some(project_id) = project_id {
2416        let room = session
2417            .db()
2418            .await
2419            .follow(room_id, project_id, leader_id, follower_id)
2420            .await?;
2421        room_updated(&room, &session.peer);
2422    }
2423
2424    Ok(())
2425}
2426
2427/// Stop following another user in a call.
2428async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2429    let room_id = RoomId::from_proto(request.room_id);
2430    let project_id = request.project_id.map(ProjectId::from_proto);
2431    let leader_id = request
2432        .leader_id
2433        .ok_or_else(|| anyhow!("invalid leader id"))?
2434        .into();
2435    let follower_id = session.connection_id;
2436
2437    session
2438        .db()
2439        .await
2440        .check_room_participants(room_id, leader_id, session.connection_id)
2441        .await?;
2442
2443    session
2444        .peer
2445        .forward_send(session.connection_id, leader_id, request)?;
2446
2447    if let Some(project_id) = project_id {
2448        let room = session
2449            .db()
2450            .await
2451            .unfollow(room_id, project_id, leader_id, follower_id)
2452            .await?;
2453        room_updated(&room, &session.peer);
2454    }
2455
2456    Ok(())
2457}
2458
2459/// Notify everyone following you of your current location.
2460async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2461    let room_id = RoomId::from_proto(request.room_id);
2462    let database = session.db.lock().await;
2463
2464    let connection_ids = if let Some(project_id) = request.project_id {
2465        let project_id = ProjectId::from_proto(project_id);
2466        database
2467            .project_connection_ids(project_id, session.connection_id, true)
2468            .await?
2469    } else {
2470        database
2471            .room_connection_ids(room_id, session.connection_id)
2472            .await?
2473    };
2474
2475    // For now, don't send view update messages back to that view's current leader.
2476    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2477        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2478        _ => None,
2479    });
2480
2481    for connection_id in connection_ids.iter().cloned() {
2482        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2483            session
2484                .peer
2485                .forward_send(session.connection_id, connection_id, request.clone())?;
2486        }
2487    }
2488    Ok(())
2489}
2490
2491/// Get public data about users.
2492async fn get_users(
2493    request: proto::GetUsers,
2494    response: Response<proto::GetUsers>,
2495    session: Session,
2496) -> Result<()> {
2497    let user_ids = request
2498        .user_ids
2499        .into_iter()
2500        .map(UserId::from_proto)
2501        .collect();
2502    let users = session
2503        .db()
2504        .await
2505        .get_users_by_ids(user_ids)
2506        .await?
2507        .into_iter()
2508        .map(|user| proto::User {
2509            id: user.id.to_proto(),
2510            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2511            github_login: user.github_login,
2512            email: user.email_address,
2513            name: user.name,
2514        })
2515        .collect();
2516    response.send(proto::UsersResponse { users })?;
2517    Ok(())
2518}
2519
2520/// Search for users (to invite) buy Github login
2521async fn fuzzy_search_users(
2522    request: proto::FuzzySearchUsers,
2523    response: Response<proto::FuzzySearchUsers>,
2524    session: Session,
2525) -> Result<()> {
2526    let query = request.query;
2527    let users = match query.len() {
2528        0 => vec![],
2529        1 | 2 => session
2530            .db()
2531            .await
2532            .get_user_by_github_login(&query)
2533            .await?
2534            .into_iter()
2535            .collect(),
2536        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2537    };
2538    let users = users
2539        .into_iter()
2540        .filter(|user| user.id != session.user_id())
2541        .map(|user| proto::User {
2542            id: user.id.to_proto(),
2543            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2544            github_login: user.github_login,
2545            name: user.name,
2546            email: user.email_address,
2547        })
2548        .collect();
2549    response.send(proto::UsersResponse { users })?;
2550    Ok(())
2551}
2552
2553/// Send a contact request to another user.
2554async fn request_contact(
2555    request: proto::RequestContact,
2556    response: Response<proto::RequestContact>,
2557    session: Session,
2558) -> Result<()> {
2559    let requester_id = session.user_id();
2560    let responder_id = UserId::from_proto(request.responder_id);
2561    if requester_id == responder_id {
2562        return Err(anyhow!("cannot add yourself as a contact"))?;
2563    }
2564
2565    let notifications = session
2566        .db()
2567        .await
2568        .send_contact_request(requester_id, responder_id)
2569        .await?;
2570
2571    // Update outgoing contact requests of requester
2572    let mut update = proto::UpdateContacts::default();
2573    update.outgoing_requests.push(responder_id.to_proto());
2574    for connection_id in session
2575        .connection_pool()
2576        .await
2577        .user_connection_ids(requester_id)
2578    {
2579        session.peer.send(connection_id, update.clone())?;
2580    }
2581
2582    // Update incoming contact requests of responder
2583    let mut update = proto::UpdateContacts::default();
2584    update
2585        .incoming_requests
2586        .push(proto::IncomingContactRequest {
2587            requester_id: requester_id.to_proto(),
2588        });
2589    let connection_pool = session.connection_pool().await;
2590    for connection_id in connection_pool.user_connection_ids(responder_id) {
2591        session.peer.send(connection_id, update.clone())?;
2592    }
2593
2594    send_notifications(&connection_pool, &session.peer, notifications);
2595
2596    response.send(proto::Ack {})?;
2597    Ok(())
2598}
2599
2600/// Accept or decline a contact request
2601async fn respond_to_contact_request(
2602    request: proto::RespondToContactRequest,
2603    response: Response<proto::RespondToContactRequest>,
2604    session: Session,
2605) -> Result<()> {
2606    let responder_id = session.user_id();
2607    let requester_id = UserId::from_proto(request.requester_id);
2608    let db = session.db().await;
2609    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2610        db.dismiss_contact_notification(responder_id, requester_id)
2611            .await?;
2612    } else {
2613        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2614
2615        let notifications = db
2616            .respond_to_contact_request(responder_id, requester_id, accept)
2617            .await?;
2618        let requester_busy = db.is_user_busy(requester_id).await?;
2619        let responder_busy = db.is_user_busy(responder_id).await?;
2620
2621        let pool = session.connection_pool().await;
2622        // Update responder with new contact
2623        let mut update = proto::UpdateContacts::default();
2624        if accept {
2625            update
2626                .contacts
2627                .push(contact_for_user(requester_id, requester_busy, &pool));
2628        }
2629        update
2630            .remove_incoming_requests
2631            .push(requester_id.to_proto());
2632        for connection_id in pool.user_connection_ids(responder_id) {
2633            session.peer.send(connection_id, update.clone())?;
2634        }
2635
2636        // Update requester with new contact
2637        let mut update = proto::UpdateContacts::default();
2638        if accept {
2639            update
2640                .contacts
2641                .push(contact_for_user(responder_id, responder_busy, &pool));
2642        }
2643        update
2644            .remove_outgoing_requests
2645            .push(responder_id.to_proto());
2646
2647        for connection_id in pool.user_connection_ids(requester_id) {
2648            session.peer.send(connection_id, update.clone())?;
2649        }
2650
2651        send_notifications(&pool, &session.peer, notifications);
2652    }
2653
2654    response.send(proto::Ack {})?;
2655    Ok(())
2656}
2657
2658/// Remove a contact.
2659async fn remove_contact(
2660    request: proto::RemoveContact,
2661    response: Response<proto::RemoveContact>,
2662    session: Session,
2663) -> Result<()> {
2664    let requester_id = session.user_id();
2665    let responder_id = UserId::from_proto(request.user_id);
2666    let db = session.db().await;
2667    let (contact_accepted, deleted_notification_id) =
2668        db.remove_contact(requester_id, responder_id).await?;
2669
2670    let pool = session.connection_pool().await;
2671    // Update outgoing contact requests of requester
2672    let mut update = proto::UpdateContacts::default();
2673    if contact_accepted {
2674        update.remove_contacts.push(responder_id.to_proto());
2675    } else {
2676        update
2677            .remove_outgoing_requests
2678            .push(responder_id.to_proto());
2679    }
2680    for connection_id in pool.user_connection_ids(requester_id) {
2681        session.peer.send(connection_id, update.clone())?;
2682    }
2683
2684    // Update incoming contact requests of responder
2685    let mut update = proto::UpdateContacts::default();
2686    if contact_accepted {
2687        update.remove_contacts.push(requester_id.to_proto());
2688    } else {
2689        update
2690            .remove_incoming_requests
2691            .push(requester_id.to_proto());
2692    }
2693    for connection_id in pool.user_connection_ids(responder_id) {
2694        session.peer.send(connection_id, update.clone())?;
2695        if let Some(notification_id) = deleted_notification_id {
2696            session.peer.send(
2697                connection_id,
2698                proto::DeleteNotification {
2699                    notification_id: notification_id.to_proto(),
2700                },
2701            )?;
2702        }
2703    }
2704
2705    response.send(proto::Ack {})?;
2706    Ok(())
2707}
2708
2709fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2710    version.0.minor() < 139
2711}
2712
2713async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2714    let plan = session.current_plan(&session.db().await).await?;
2715
2716    session
2717        .peer
2718        .send(
2719            session.connection_id,
2720            proto::UpdateUserPlan { plan: plan.into() },
2721        )
2722        .trace_err();
2723
2724    Ok(())
2725}
2726
2727async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2728    subscribe_user_to_channels(session.user_id(), &session).await?;
2729    Ok(())
2730}
2731
2732async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2733    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2734    let mut pool = session.connection_pool().await;
2735    for membership in &channels_for_user.channel_memberships {
2736        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2737    }
2738    session.peer.send(
2739        session.connection_id,
2740        build_update_user_channels(&channels_for_user),
2741    )?;
2742    session.peer.send(
2743        session.connection_id,
2744        build_channels_update(channels_for_user),
2745    )?;
2746    Ok(())
2747}
2748
2749/// Creates a new channel.
2750async fn create_channel(
2751    request: proto::CreateChannel,
2752    response: Response<proto::CreateChannel>,
2753    session: Session,
2754) -> Result<()> {
2755    let db = session.db().await;
2756
2757    let parent_id = request.parent_id.map(ChannelId::from_proto);
2758    let (channel, membership) = db
2759        .create_channel(&request.name, parent_id, session.user_id())
2760        .await?;
2761
2762    let root_id = channel.root_id();
2763    let channel = Channel::from_model(channel);
2764
2765    response.send(proto::CreateChannelResponse {
2766        channel: Some(channel.to_proto()),
2767        parent_id: request.parent_id,
2768    })?;
2769
2770    let mut connection_pool = session.connection_pool().await;
2771    if let Some(membership) = membership {
2772        connection_pool.subscribe_to_channel(
2773            membership.user_id,
2774            membership.channel_id,
2775            membership.role,
2776        );
2777        let update = proto::UpdateUserChannels {
2778            channel_memberships: vec![proto::ChannelMembership {
2779                channel_id: membership.channel_id.to_proto(),
2780                role: membership.role.into(),
2781            }],
2782            ..Default::default()
2783        };
2784        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2785            session.peer.send(connection_id, update.clone())?;
2786        }
2787    }
2788
2789    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2790        if !role.can_see_channel(channel.visibility) {
2791            continue;
2792        }
2793
2794        let update = proto::UpdateChannels {
2795            channels: vec![channel.to_proto()],
2796            ..Default::default()
2797        };
2798        session.peer.send(connection_id, update.clone())?;
2799    }
2800
2801    Ok(())
2802}
2803
2804/// Delete a channel
2805async fn delete_channel(
2806    request: proto::DeleteChannel,
2807    response: Response<proto::DeleteChannel>,
2808    session: Session,
2809) -> Result<()> {
2810    let db = session.db().await;
2811
2812    let channel_id = request.channel_id;
2813    let (root_channel, removed_channels) = db
2814        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2815        .await?;
2816    response.send(proto::Ack {})?;
2817
2818    // Notify members of removed channels
2819    let mut update = proto::UpdateChannels::default();
2820    update
2821        .delete_channels
2822        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2823
2824    let connection_pool = session.connection_pool().await;
2825    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2826        session.peer.send(connection_id, update.clone())?;
2827    }
2828
2829    Ok(())
2830}
2831
2832/// Invite someone to join a channel.
2833async fn invite_channel_member(
2834    request: proto::InviteChannelMember,
2835    response: Response<proto::InviteChannelMember>,
2836    session: Session,
2837) -> Result<()> {
2838    let db = session.db().await;
2839    let channel_id = ChannelId::from_proto(request.channel_id);
2840    let invitee_id = UserId::from_proto(request.user_id);
2841    let InviteMemberResult {
2842        channel,
2843        notifications,
2844    } = db
2845        .invite_channel_member(
2846            channel_id,
2847            invitee_id,
2848            session.user_id(),
2849            request.role().into(),
2850        )
2851        .await?;
2852
2853    let update = proto::UpdateChannels {
2854        channel_invitations: vec![channel.to_proto()],
2855        ..Default::default()
2856    };
2857
2858    let connection_pool = session.connection_pool().await;
2859    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2860        session.peer.send(connection_id, update.clone())?;
2861    }
2862
2863    send_notifications(&connection_pool, &session.peer, notifications);
2864
2865    response.send(proto::Ack {})?;
2866    Ok(())
2867}
2868
2869/// remove someone from a channel
2870async fn remove_channel_member(
2871    request: proto::RemoveChannelMember,
2872    response: Response<proto::RemoveChannelMember>,
2873    session: Session,
2874) -> Result<()> {
2875    let db = session.db().await;
2876    let channel_id = ChannelId::from_proto(request.channel_id);
2877    let member_id = UserId::from_proto(request.user_id);
2878
2879    let RemoveChannelMemberResult {
2880        membership_update,
2881        notification_id,
2882    } = db
2883        .remove_channel_member(channel_id, member_id, session.user_id())
2884        .await?;
2885
2886    let mut connection_pool = session.connection_pool().await;
2887    notify_membership_updated(
2888        &mut connection_pool,
2889        membership_update,
2890        member_id,
2891        &session.peer,
2892    );
2893    for connection_id in connection_pool.user_connection_ids(member_id) {
2894        if let Some(notification_id) = notification_id {
2895            session
2896                .peer
2897                .send(
2898                    connection_id,
2899                    proto::DeleteNotification {
2900                        notification_id: notification_id.to_proto(),
2901                    },
2902                )
2903                .trace_err();
2904        }
2905    }
2906
2907    response.send(proto::Ack {})?;
2908    Ok(())
2909}
2910
2911/// Toggle the channel between public and private.
2912/// Care is taken to maintain the invariant that public channels only descend from public channels,
2913/// (though members-only channels can appear at any point in the hierarchy).
2914async fn set_channel_visibility(
2915    request: proto::SetChannelVisibility,
2916    response: Response<proto::SetChannelVisibility>,
2917    session: Session,
2918) -> Result<()> {
2919    let db = session.db().await;
2920    let channel_id = ChannelId::from_proto(request.channel_id);
2921    let visibility = request.visibility().into();
2922
2923    let channel_model = db
2924        .set_channel_visibility(channel_id, visibility, session.user_id())
2925        .await?;
2926    let root_id = channel_model.root_id();
2927    let channel = Channel::from_model(channel_model);
2928
2929    let mut connection_pool = session.connection_pool().await;
2930    for (user_id, role) in connection_pool
2931        .channel_user_ids(root_id)
2932        .collect::<Vec<_>>()
2933        .into_iter()
2934    {
2935        let update = if role.can_see_channel(channel.visibility) {
2936            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2937            proto::UpdateChannels {
2938                channels: vec![channel.to_proto()],
2939                ..Default::default()
2940            }
2941        } else {
2942            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2943            proto::UpdateChannels {
2944                delete_channels: vec![channel.id.to_proto()],
2945                ..Default::default()
2946            }
2947        };
2948
2949        for connection_id in connection_pool.user_connection_ids(user_id) {
2950            session.peer.send(connection_id, update.clone())?;
2951        }
2952    }
2953
2954    response.send(proto::Ack {})?;
2955    Ok(())
2956}
2957
2958/// Alter the role for a user in the channel.
2959async fn set_channel_member_role(
2960    request: proto::SetChannelMemberRole,
2961    response: Response<proto::SetChannelMemberRole>,
2962    session: Session,
2963) -> Result<()> {
2964    let db = session.db().await;
2965    let channel_id = ChannelId::from_proto(request.channel_id);
2966    let member_id = UserId::from_proto(request.user_id);
2967    let result = db
2968        .set_channel_member_role(
2969            channel_id,
2970            session.user_id(),
2971            member_id,
2972            request.role().into(),
2973        )
2974        .await?;
2975
2976    match result {
2977        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2978            let mut connection_pool = session.connection_pool().await;
2979            notify_membership_updated(
2980                &mut connection_pool,
2981                membership_update,
2982                member_id,
2983                &session.peer,
2984            )
2985        }
2986        db::SetMemberRoleResult::InviteUpdated(channel) => {
2987            let update = proto::UpdateChannels {
2988                channel_invitations: vec![channel.to_proto()],
2989                ..Default::default()
2990            };
2991
2992            for connection_id in session
2993                .connection_pool()
2994                .await
2995                .user_connection_ids(member_id)
2996            {
2997                session.peer.send(connection_id, update.clone())?;
2998            }
2999        }
3000    }
3001
3002    response.send(proto::Ack {})?;
3003    Ok(())
3004}
3005
3006/// Change the name of a channel
3007async fn rename_channel(
3008    request: proto::RenameChannel,
3009    response: Response<proto::RenameChannel>,
3010    session: Session,
3011) -> Result<()> {
3012    let db = session.db().await;
3013    let channel_id = ChannelId::from_proto(request.channel_id);
3014    let channel_model = db
3015        .rename_channel(channel_id, session.user_id(), &request.name)
3016        .await?;
3017    let root_id = channel_model.root_id();
3018    let channel = Channel::from_model(channel_model);
3019
3020    response.send(proto::RenameChannelResponse {
3021        channel: Some(channel.to_proto()),
3022    })?;
3023
3024    let connection_pool = session.connection_pool().await;
3025    let update = proto::UpdateChannels {
3026        channels: vec![channel.to_proto()],
3027        ..Default::default()
3028    };
3029    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3030        if role.can_see_channel(channel.visibility) {
3031            session.peer.send(connection_id, update.clone())?;
3032        }
3033    }
3034
3035    Ok(())
3036}
3037
3038/// Move a channel to a new parent.
3039async fn move_channel(
3040    request: proto::MoveChannel,
3041    response: Response<proto::MoveChannel>,
3042    session: Session,
3043) -> Result<()> {
3044    let channel_id = ChannelId::from_proto(request.channel_id);
3045    let to = ChannelId::from_proto(request.to);
3046
3047    let (root_id, channels) = session
3048        .db()
3049        .await
3050        .move_channel(channel_id, to, session.user_id())
3051        .await?;
3052
3053    let connection_pool = session.connection_pool().await;
3054    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055        let channels = channels
3056            .iter()
3057            .filter_map(|channel| {
3058                if role.can_see_channel(channel.visibility) {
3059                    Some(channel.to_proto())
3060                } else {
3061                    None
3062                }
3063            })
3064            .collect::<Vec<_>>();
3065        if channels.is_empty() {
3066            continue;
3067        }
3068
3069        let update = proto::UpdateChannels {
3070            channels,
3071            ..Default::default()
3072        };
3073
3074        session.peer.send(connection_id, update.clone())?;
3075    }
3076
3077    response.send(Ack {})?;
3078    Ok(())
3079}
3080
3081/// Get the list of channel members
3082async fn get_channel_members(
3083    request: proto::GetChannelMembers,
3084    response: Response<proto::GetChannelMembers>,
3085    session: Session,
3086) -> Result<()> {
3087    let db = session.db().await;
3088    let channel_id = ChannelId::from_proto(request.channel_id);
3089    let limit = if request.limit == 0 {
3090        u16::MAX as u64
3091    } else {
3092        request.limit
3093    };
3094    let (members, users) = db
3095        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3096        .await?;
3097    response.send(proto::GetChannelMembersResponse { members, users })?;
3098    Ok(())
3099}
3100
3101/// Accept or decline a channel invitation.
3102async fn respond_to_channel_invite(
3103    request: proto::RespondToChannelInvite,
3104    response: Response<proto::RespondToChannelInvite>,
3105    session: Session,
3106) -> Result<()> {
3107    let db = session.db().await;
3108    let channel_id = ChannelId::from_proto(request.channel_id);
3109    let RespondToChannelInvite {
3110        membership_update,
3111        notifications,
3112    } = db
3113        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3114        .await?;
3115
3116    let mut connection_pool = session.connection_pool().await;
3117    if let Some(membership_update) = membership_update {
3118        notify_membership_updated(
3119            &mut connection_pool,
3120            membership_update,
3121            session.user_id(),
3122            &session.peer,
3123        );
3124    } else {
3125        let update = proto::UpdateChannels {
3126            remove_channel_invitations: vec![channel_id.to_proto()],
3127            ..Default::default()
3128        };
3129
3130        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3131            session.peer.send(connection_id, update.clone())?;
3132        }
3133    };
3134
3135    send_notifications(&connection_pool, &session.peer, notifications);
3136
3137    response.send(proto::Ack {})?;
3138
3139    Ok(())
3140}
3141
3142/// Join the channels' room
3143async fn join_channel(
3144    request: proto::JoinChannel,
3145    response: Response<proto::JoinChannel>,
3146    session: Session,
3147) -> Result<()> {
3148    let channel_id = ChannelId::from_proto(request.channel_id);
3149    join_channel_internal(channel_id, Box::new(response), session).await
3150}
3151
/// A reply handle that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can send its reply without knowing which of the
/// two request types it is serving.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The explicit `Type::send` path resolves to the inherent
        // `Response::send` (inherent items take precedence over trait items),
        // so this delegates rather than recursing into the trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation to the inherent `Response::send` as above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3165
/// Joins the room of `channel_id` for the session's user and replies with a
/// `proto::JoinRoomResponse` through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed with `leave_room_for_session`.
/// 2. The user joins the channel in the database. This yields the joined room,
///    an optional membership update, and the user's channel role.
/// 3. LiveKit credentials are minted when a LiveKit client is configured.
///    Guests get a `guest_token` with `can_publish: false`; other roles get a
///    `room_token` with `can_publish: true`. If there is no client, or minting
///    fails (the error goes through `trace_err`), the response carries no
///    LiveKit connection info.
/// 4. The response is sent, any membership update is pushed, and the room is
///    broadcast with `room_updated`.
/// 5. `channel_updated` is called for the joined channel, and the user's
///    contacts are refreshed with `update_user_contacts`.
///
/// # Errors
///
/// Returns an error if a database call, leaving the stale room, sending the
/// response, or updating contacts fails. It also errors if the database did
/// not return the joined channel ("channel not returned"); that check runs
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Both the DB guard (`db`) and the connection-pool guard taken below live
    // until the end of this block.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB lock before leaving the stale room:
            // `leave_room_for_session` receives the whole session and presumably
            // takes the lock itself (confirm). The lock is re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token minting failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above are released here; the pool is locked
    // again just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3261
3262/// Start editing the channel notes
3263async fn join_channel_buffer(
3264    request: proto::JoinChannelBuffer,
3265    response: Response<proto::JoinChannelBuffer>,
3266    session: Session,
3267) -> Result<()> {
3268    let db = session.db().await;
3269    let channel_id = ChannelId::from_proto(request.channel_id);
3270
3271    let open_response = db
3272        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3273        .await?;
3274
3275    let collaborators = open_response.collaborators.clone();
3276    response.send(open_response)?;
3277
3278    let update = UpdateChannelBufferCollaborators {
3279        channel_id: channel_id.to_proto(),
3280        collaborators: collaborators.clone(),
3281    };
3282    channel_buffer_updated(
3283        session.connection_id,
3284        collaborators
3285            .iter()
3286            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3287        &update,
3288        &session.peer,
3289    );
3290
3291    Ok(())
3292}
3293
3294/// Edit the channel notes
3295async fn update_channel_buffer(
3296    request: proto::UpdateChannelBuffer,
3297    session: Session,
3298) -> Result<()> {
3299    let db = session.db().await;
3300    let channel_id = ChannelId::from_proto(request.channel_id);
3301
3302    let (collaborators, epoch, version) = db
3303        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3304        .await?;
3305
3306    channel_buffer_updated(
3307        session.connection_id,
3308        collaborators.clone(),
3309        &proto::UpdateChannelBuffer {
3310            channel_id: channel_id.to_proto(),
3311            operations: request.operations,
3312        },
3313        &session.peer,
3314    );
3315
3316    let pool = &*session.connection_pool().await;
3317
3318    let non_collaborators =
3319        pool.channel_connection_ids(channel_id)
3320            .filter_map(|(connection_id, _)| {
3321                if collaborators.contains(&connection_id) {
3322                    None
3323                } else {
3324                    Some(connection_id)
3325                }
3326            });
3327
3328    broadcast(None, non_collaborators, |peer_id| {
3329        session.peer.send(
3330            peer_id,
3331            proto::UpdateChannels {
3332                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3333                    channel_id: channel_id.to_proto(),
3334                    epoch: epoch as u64,
3335                    version: version.clone(),
3336                }],
3337                ..Default::default()
3338            },
3339        )
3340    });
3341
3342    Ok(())
3343}
3344
3345/// Rejoin the channel notes after a connection blip
3346async fn rejoin_channel_buffers(
3347    request: proto::RejoinChannelBuffers,
3348    response: Response<proto::RejoinChannelBuffers>,
3349    session: Session,
3350) -> Result<()> {
3351    let db = session.db().await;
3352    let buffers = db
3353        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3354        .await?;
3355
3356    for rejoined_buffer in &buffers {
3357        let collaborators_to_notify = rejoined_buffer
3358            .buffer
3359            .collaborators
3360            .iter()
3361            .filter_map(|c| Some(c.peer_id?.into()));
3362        channel_buffer_updated(
3363            session.connection_id,
3364            collaborators_to_notify,
3365            &proto::UpdateChannelBufferCollaborators {
3366                channel_id: rejoined_buffer.buffer.channel_id,
3367                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3368            },
3369            &session.peer,
3370        );
3371    }
3372
3373    response.send(proto::RejoinChannelBuffersResponse {
3374        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3375    })?;
3376
3377    Ok(())
3378}
3379
3380/// Stop editing the channel notes
3381async fn leave_channel_buffer(
3382    request: proto::LeaveChannelBuffer,
3383    response: Response<proto::LeaveChannelBuffer>,
3384    session: Session,
3385) -> Result<()> {
3386    let db = session.db().await;
3387    let channel_id = ChannelId::from_proto(request.channel_id);
3388
3389    let left_buffer = db
3390        .leave_channel_buffer(channel_id, session.connection_id)
3391        .await?;
3392
3393    response.send(Ack {})?;
3394
3395    channel_buffer_updated(
3396        session.connection_id,
3397        left_buffer.connections,
3398        &proto::UpdateChannelBufferCollaborators {
3399            channel_id: channel_id.to_proto(),
3400            collaborators: left_buffer.collaborators,
3401        },
3402        &session.peer,
3403    );
3404
3405    Ok(())
3406}
3407
3408fn channel_buffer_updated<T: EnvelopedMessage>(
3409    sender_id: ConnectionId,
3410    collaborators: impl IntoIterator<Item = ConnectionId>,
3411    message: &T,
3412    peer: &Peer,
3413) {
3414    broadcast(Some(sender_id), collaborators, |peer_id| {
3415        peer.send(peer_id, message.clone())
3416    });
3417}
3418
3419fn send_notifications(
3420    connection_pool: &ConnectionPool,
3421    peer: &Peer,
3422    notifications: db::NotificationBatch,
3423) {
3424    for (user_id, notification) in notifications {
3425        for connection_id in connection_pool.user_connection_ids(user_id) {
3426            if let Err(error) = peer.send(
3427                connection_id,
3428                proto::AddNotification {
3429                    notification: Some(notification.clone()),
3430                },
3431            ) {
3432                tracing::error!(
3433                    "failed to send notification to {:?} {}",
3434                    connection_id,
3435                    error
3436                );
3437            }
3438        }
3439    }
3440}
3441
3442/// Send a message to the channel
3443async fn send_channel_message(
3444    request: proto::SendChannelMessage,
3445    response: Response<proto::SendChannelMessage>,
3446    session: Session,
3447) -> Result<()> {
3448    // Validate the message body.
3449    let body = request.body.trim().to_string();
3450    if body.len() > MAX_MESSAGE_LEN {
3451        return Err(anyhow!("message is too long"))?;
3452    }
3453    if body.is_empty() {
3454        return Err(anyhow!("message can't be blank"))?;
3455    }
3456
3457    // TODO: adjust mentions if body is trimmed
3458
3459    let timestamp = OffsetDateTime::now_utc();
3460    let nonce = request
3461        .nonce
3462        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3463
3464    let channel_id = ChannelId::from_proto(request.channel_id);
3465    let CreatedChannelMessage {
3466        message_id,
3467        participant_connection_ids,
3468        notifications,
3469    } = session
3470        .db()
3471        .await
3472        .create_channel_message(
3473            channel_id,
3474            session.user_id(),
3475            &body,
3476            &request.mentions,
3477            timestamp,
3478            nonce.clone().into(),
3479            request.reply_to_message_id.map(MessageId::from_proto),
3480        )
3481        .await?;
3482
3483    let message = proto::ChannelMessage {
3484        sender_id: session.user_id().to_proto(),
3485        id: message_id.to_proto(),
3486        body,
3487        mentions: request.mentions,
3488        timestamp: timestamp.unix_timestamp() as u64,
3489        nonce: Some(nonce),
3490        reply_to_message_id: request.reply_to_message_id,
3491        edited_at: None,
3492    };
3493    broadcast(
3494        Some(session.connection_id),
3495        participant_connection_ids.clone(),
3496        |connection| {
3497            session.peer.send(
3498                connection,
3499                proto::ChannelMessageSent {
3500                    channel_id: channel_id.to_proto(),
3501                    message: Some(message.clone()),
3502                },
3503            )
3504        },
3505    );
3506    response.send(proto::SendChannelMessageResponse {
3507        message: Some(message),
3508    })?;
3509
3510    let pool = &*session.connection_pool().await;
3511    let non_participants =
3512        pool.channel_connection_ids(channel_id)
3513            .filter_map(|(connection_id, _)| {
3514                if participant_connection_ids.contains(&connection_id) {
3515                    None
3516                } else {
3517                    Some(connection_id)
3518                }
3519            });
3520    broadcast(None, non_participants, |peer_id| {
3521        session.peer.send(
3522            peer_id,
3523            proto::UpdateChannels {
3524                latest_channel_message_ids: vec![proto::ChannelMessageId {
3525                    channel_id: channel_id.to_proto(),
3526                    message_id: message_id.to_proto(),
3527                }],
3528                ..Default::default()
3529            },
3530        )
3531    });
3532    send_notifications(pool, &session.peer, notifications);
3533
3534    Ok(())
3535}
3536
3537/// Delete a channel message
3538async fn remove_channel_message(
3539    request: proto::RemoveChannelMessage,
3540    response: Response<proto::RemoveChannelMessage>,
3541    session: Session,
3542) -> Result<()> {
3543    let channel_id = ChannelId::from_proto(request.channel_id);
3544    let message_id = MessageId::from_proto(request.message_id);
3545    let (connection_ids, existing_notification_ids) = session
3546        .db()
3547        .await
3548        .remove_channel_message(channel_id, message_id, session.user_id())
3549        .await?;
3550
3551    broadcast(
3552        Some(session.connection_id),
3553        connection_ids,
3554        move |connection| {
3555            session.peer.send(connection, request.clone())?;
3556
3557            for notification_id in &existing_notification_ids {
3558                session.peer.send(
3559                    connection,
3560                    proto::DeleteNotification {
3561                        notification_id: (*notification_id).to_proto(),
3562                    },
3563                )?;
3564            }
3565
3566            Ok(())
3567        },
3568    );
3569    response.send(proto::Ack {})?;
3570    Ok(())
3571}
3572
3573async fn update_channel_message(
3574    request: proto::UpdateChannelMessage,
3575    response: Response<proto::UpdateChannelMessage>,
3576    session: Session,
3577) -> Result<()> {
3578    let channel_id = ChannelId::from_proto(request.channel_id);
3579    let message_id = MessageId::from_proto(request.message_id);
3580    let updated_at = OffsetDateTime::now_utc();
3581    let UpdatedChannelMessage {
3582        message_id,
3583        participant_connection_ids,
3584        notifications,
3585        reply_to_message_id,
3586        timestamp,
3587        deleted_mention_notification_ids,
3588        updated_mention_notifications,
3589    } = session
3590        .db()
3591        .await
3592        .update_channel_message(
3593            channel_id,
3594            message_id,
3595            session.user_id(),
3596            request.body.as_str(),
3597            &request.mentions,
3598            updated_at,
3599        )
3600        .await?;
3601
3602    let nonce = request
3603        .nonce
3604        .clone()
3605        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3606
3607    let message = proto::ChannelMessage {
3608        sender_id: session.user_id().to_proto(),
3609        id: message_id.to_proto(),
3610        body: request.body.clone(),
3611        mentions: request.mentions.clone(),
3612        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3613        nonce: Some(nonce),
3614        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3615        edited_at: Some(updated_at.unix_timestamp() as u64),
3616    };
3617
3618    response.send(proto::Ack {})?;
3619
3620    let pool = &*session.connection_pool().await;
3621    broadcast(
3622        Some(session.connection_id),
3623        participant_connection_ids,
3624        |connection| {
3625            session.peer.send(
3626                connection,
3627                proto::ChannelMessageUpdate {
3628                    channel_id: channel_id.to_proto(),
3629                    message: Some(message.clone()),
3630                },
3631            )?;
3632
3633            for notification_id in &deleted_mention_notification_ids {
3634                session.peer.send(
3635                    connection,
3636                    proto::DeleteNotification {
3637                        notification_id: (*notification_id).to_proto(),
3638                    },
3639                )?;
3640            }
3641
3642            for notification in &updated_mention_notifications {
3643                session.peer.send(
3644                    connection,
3645                    proto::UpdateNotification {
3646                        notification: Some(notification.clone()),
3647                    },
3648                )?;
3649            }
3650
3651            Ok(())
3652        },
3653    );
3654
3655    send_notifications(pool, &session.peer, notifications);
3656
3657    Ok(())
3658}
3659
3660/// Mark a channel message as read
3661async fn acknowledge_channel_message(
3662    request: proto::AckChannelMessage,
3663    session: Session,
3664) -> Result<()> {
3665    let channel_id = ChannelId::from_proto(request.channel_id);
3666    let message_id = MessageId::from_proto(request.message_id);
3667    let notifications = session
3668        .db()
3669        .await
3670        .observe_channel_message(channel_id, session.user_id(), message_id)
3671        .await?;
3672    send_notifications(
3673        &*session.connection_pool().await,
3674        &session.peer,
3675        notifications,
3676    );
3677    Ok(())
3678}
3679
3680/// Mark a buffer version as synced
3681async fn acknowledge_buffer_version(
3682    request: proto::AckBufferOperation,
3683    session: Session,
3684) -> Result<()> {
3685    let buffer_id = BufferId::from_proto(request.buffer_id);
3686    session
3687        .db()
3688        .await
3689        .observe_buffer_version(
3690            buffer_id,
3691            session.user_id(),
3692            request.epoch as i32,
3693            &request.version,
3694        )
3695        .await?;
3696    Ok(())
3697}
3698
3699async fn count_language_model_tokens(
3700    request: proto::CountLanguageModelTokens,
3701    response: Response<proto::CountLanguageModelTokens>,
3702    session: Session,
3703    config: &Config,
3704) -> Result<()> {
3705    authorize_access_to_legacy_llm_endpoints(&session).await?;
3706
3707    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3708        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3709        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3710    };
3711
3712    session
3713        .app_state
3714        .rate_limiter
3715        .check(&*rate_limit, session.user_id())
3716        .await?;
3717
3718    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3719        Some(proto::LanguageModelProvider::Google) => {
3720            let api_key = config
3721                .google_ai_api_key
3722                .as_ref()
3723                .context("no Google AI API key configured on the server")?;
3724            google_ai::count_tokens(
3725                session.http_client.as_ref(),
3726                google_ai::API_URL,
3727                api_key,
3728                serde_json::from_str(&request.request)?,
3729            )
3730            .await?
3731        }
3732        _ => return Err(anyhow!("unsupported provider"))?,
3733    };
3734
3735    response.send(proto::CountLanguageModelTokensResponse {
3736        token_count: result.total_tokens as u32,
3737    })?;
3738
3739    Ok(())
3740}
3741
3742struct ZedProCountLanguageModelTokensRateLimit;
3743
3744impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3745    fn capacity(&self) -> usize {
3746        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3747            .ok()
3748            .and_then(|v| v.parse().ok())
3749            .unwrap_or(600) // Picked arbitrarily
3750    }
3751
3752    fn refill_duration(&self) -> chrono::Duration {
3753        chrono::Duration::hours(1)
3754    }
3755
3756    fn db_name(&self) -> &'static str {
3757        "zed-pro:count-language-model-tokens"
3758    }
3759}
3760
3761struct FreeCountLanguageModelTokensRateLimit;
3762
3763impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3764    fn capacity(&self) -> usize {
3765        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3766            .ok()
3767            .and_then(|v| v.parse().ok())
3768            .unwrap_or(600 / 10) // Picked arbitrarily
3769    }
3770
3771    fn refill_duration(&self) -> chrono::Duration {
3772        chrono::Duration::hours(1)
3773    }
3774
3775    fn db_name(&self) -> &'static str {
3776        "free:count-language-model-tokens"
3777    }
3778}
3779
3780struct ZedProComputeEmbeddingsRateLimit;
3781
3782impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3783    fn capacity(&self) -> usize {
3784        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3785            .ok()
3786            .and_then(|v| v.parse().ok())
3787            .unwrap_or(5000) // Picked arbitrarily
3788    }
3789
3790    fn refill_duration(&self) -> chrono::Duration {
3791        chrono::Duration::hours(1)
3792    }
3793
3794    fn db_name(&self) -> &'static str {
3795        "zed-pro:compute-embeddings"
3796    }
3797}
3798
3799struct FreeComputeEmbeddingsRateLimit;
3800
3801impl RateLimit for FreeComputeEmbeddingsRateLimit {
3802    fn capacity(&self) -> usize {
3803        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3804            .ok()
3805            .and_then(|v| v.parse().ok())
3806            .unwrap_or(5000 / 10) // Picked arbitrarily
3807    }
3808
3809    fn refill_duration(&self) -> chrono::Duration {
3810        chrono::Duration::hours(1)
3811    }
3812
3813    fn db_name(&self) -> &'static str {
3814        "free:compute-embeddings"
3815    }
3816}
3817
3818async fn compute_embeddings(
3819    request: proto::ComputeEmbeddings,
3820    response: Response<proto::ComputeEmbeddings>,
3821    session: Session,
3822    api_key: Option<Arc<str>>,
3823) -> Result<()> {
3824    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3825    authorize_access_to_legacy_llm_endpoints(&session).await?;
3826
3827    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3828        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3829        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3830    };
3831
3832    session
3833        .app_state
3834        .rate_limiter
3835        .check(&*rate_limit, session.user_id())
3836        .await?;
3837
3838    let embeddings = match request.model.as_str() {
3839        "openai/text-embedding-3-small" => {
3840            open_ai::embed(
3841                session.http_client.as_ref(),
3842                OPEN_AI_API_URL,
3843                &api_key,
3844                OpenAiEmbeddingModel::TextEmbedding3Small,
3845                request.texts.iter().map(|text| text.as_str()),
3846            )
3847            .await?
3848        }
3849        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3850    };
3851
3852    let embeddings = request
3853        .texts
3854        .iter()
3855        .map(|text| {
3856            let mut hasher = sha2::Sha256::new();
3857            hasher.update(text.as_bytes());
3858            let result = hasher.finalize();
3859            result.to_vec()
3860        })
3861        .zip(
3862            embeddings
3863                .data
3864                .into_iter()
3865                .map(|embedding| embedding.embedding),
3866        )
3867        .collect::<HashMap<_, _>>();
3868
3869    let db = session.db().await;
3870    db.save_embeddings(&request.model, &embeddings)
3871        .await
3872        .context("failed to save embeddings")
3873        .trace_err();
3874
3875    response.send(proto::ComputeEmbeddingsResponse {
3876        embeddings: embeddings
3877            .into_iter()
3878            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3879            .collect(),
3880    })?;
3881    Ok(())
3882}
3883
3884async fn get_cached_embeddings(
3885    request: proto::GetCachedEmbeddings,
3886    response: Response<proto::GetCachedEmbeddings>,
3887    session: Session,
3888) -> Result<()> {
3889    authorize_access_to_legacy_llm_endpoints(&session).await?;
3890
3891    let db = session.db().await;
3892    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3893
3894    response.send(proto::GetCachedEmbeddingsResponse {
3895        embeddings: embeddings
3896            .into_iter()
3897            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3898            .collect(),
3899    })?;
3900    Ok(())
3901}
3902
3903/// This is leftover from before the LLM service.
3904///
3905/// The endpoints protected by this check will be moved there eventually.
3906async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3907    if session.is_staff() {
3908        Ok(())
3909    } else {
3910        Err(anyhow!("permission denied"))?
3911    }
3912}
3913
3914/// Get a Supermaven API key for the user
3915async fn get_supermaven_api_key(
3916    _request: proto::GetSupermavenApiKey,
3917    response: Response<proto::GetSupermavenApiKey>,
3918    session: Session,
3919) -> Result<()> {
3920    let user_id: String = session.user_id().to_string();
3921    if !session.is_staff() {
3922        return Err(anyhow!("supermaven not enabled for this account"))?;
3923    }
3924
3925    let email = session
3926        .email()
3927        .ok_or_else(|| anyhow!("user must have an email"))?;
3928
3929    let supermaven_admin_api = session
3930        .supermaven_client
3931        .as_ref()
3932        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3933
3934    let result = supermaven_admin_api
3935        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3936        .await?;
3937
3938    response.send(proto::GetSupermavenApiKeyResponse {
3939        api_key: result.api_key,
3940    })?;
3941
3942    Ok(())
3943}
3944
3945/// Start receiving chat updates for a channel
3946async fn join_channel_chat(
3947    request: proto::JoinChannelChat,
3948    response: Response<proto::JoinChannelChat>,
3949    session: Session,
3950) -> Result<()> {
3951    let channel_id = ChannelId::from_proto(request.channel_id);
3952
3953    let db = session.db().await;
3954    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3955        .await?;
3956    let messages = db
3957        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3958        .await?;
3959    response.send(proto::JoinChannelChatResponse {
3960        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3961        messages,
3962    })?;
3963    Ok(())
3964}
3965
3966/// Stop receiving chat updates for a channel
3967async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3968    let channel_id = ChannelId::from_proto(request.channel_id);
3969    session
3970        .db()
3971        .await
3972        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3973        .await?;
3974    Ok(())
3975}
3976
3977/// Retrieve the chat history for a channel
3978async fn get_channel_messages(
3979    request: proto::GetChannelMessages,
3980    response: Response<proto::GetChannelMessages>,
3981    session: Session,
3982) -> Result<()> {
3983    let channel_id = ChannelId::from_proto(request.channel_id);
3984    let messages = session
3985        .db()
3986        .await
3987        .get_channel_messages(
3988            channel_id,
3989            session.user_id(),
3990            MESSAGE_COUNT_PER_PAGE,
3991            Some(MessageId::from_proto(request.before_message_id)),
3992        )
3993        .await?;
3994    response.send(proto::GetChannelMessagesResponse {
3995        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3996        messages,
3997    })?;
3998    Ok(())
3999}
4000
4001/// Retrieve specific chat messages
4002async fn get_channel_messages_by_id(
4003    request: proto::GetChannelMessagesById,
4004    response: Response<proto::GetChannelMessagesById>,
4005    session: Session,
4006) -> Result<()> {
4007    let message_ids = request
4008        .message_ids
4009        .iter()
4010        .map(|id| MessageId::from_proto(*id))
4011        .collect::<Vec<_>>();
4012    let messages = session
4013        .db()
4014        .await
4015        .get_channel_messages_by_id(session.user_id(), &message_ids)
4016        .await?;
4017    response.send(proto::GetChannelMessagesResponse {
4018        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4019        messages,
4020    })?;
4021    Ok(())
4022}
4023
4024/// Retrieve the current users notifications
4025async fn get_notifications(
4026    request: proto::GetNotifications,
4027    response: Response<proto::GetNotifications>,
4028    session: Session,
4029) -> Result<()> {
4030    let notifications = session
4031        .db()
4032        .await
4033        .get_notifications(
4034            session.user_id(),
4035            NOTIFICATION_COUNT_PER_PAGE,
4036            request.before_id.map(db::NotificationId::from_proto),
4037        )
4038        .await?;
4039    response.send(proto::GetNotificationsResponse {
4040        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4041        notifications,
4042    })?;
4043    Ok(())
4044}
4045
4046/// Mark notifications as read
4047async fn mark_notification_as_read(
4048    request: proto::MarkNotificationRead,
4049    response: Response<proto::MarkNotificationRead>,
4050    session: Session,
4051) -> Result<()> {
4052    let database = &session.db().await;
4053    let notifications = database
4054        .mark_notification_as_read_by_id(
4055            session.user_id(),
4056            NotificationId::from_proto(request.notification_id),
4057        )
4058        .await?;
4059    send_notifications(
4060        &*session.connection_pool().await,
4061        &session.peer,
4062        notifications,
4063    );
4064    response.send(proto::Ack {})?;
4065    Ok(())
4066}
4067
4068/// Get the current users information
4069async fn get_private_user_info(
4070    _request: proto::GetPrivateUserInfo,
4071    response: Response<proto::GetPrivateUserInfo>,
4072    session: Session,
4073) -> Result<()> {
4074    let db = session.db().await;
4075
4076    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4077    let user = db
4078        .get_user_by_id(session.user_id())
4079        .await?
4080        .ok_or_else(|| anyhow!("user not found"))?;
4081    let flags = db.get_user_flags(session.user_id()).await?;
4082
4083    response.send(proto::GetPrivateUserInfoResponse {
4084        metrics_id,
4085        staff: user.admin,
4086        flags,
4087        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4088    })?;
4089    Ok(())
4090}
4091
4092/// Accept the terms of service (tos) on behalf of the current user
4093async fn accept_terms_of_service(
4094    _request: proto::AcceptTermsOfService,
4095    response: Response<proto::AcceptTermsOfService>,
4096    session: Session,
4097) -> Result<()> {
4098    let db = session.db().await;
4099
4100    let accepted_tos_at = Utc::now();
4101    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4102        .await?;
4103
4104    response.send(proto::AcceptTermsOfServiceResponse {
4105        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4106    })?;
4107    Ok(())
4108}
4109
/// The minimum account age an account must have in order to use the LLM service.
///
/// Currently 30 days. NOTE(review): not referenced by the handlers in this part
/// of the file; presumably enforced elsewhere — confirm at the use site.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4112
4113async fn get_llm_api_token(
4114    _request: proto::GetLlmToken,
4115    response: Response<proto::GetLlmToken>,
4116    session: Session,
4117) -> Result<()> {
4118    let db = session.db().await;
4119
4120    let flags = db.get_user_flags(session.user_id()).await?;
4121    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4122
4123    if !session.is_staff() && !has_language_models_feature_flag {
4124        Err(anyhow!("permission denied"))?
4125    }
4126
4127    let user_id = session.user_id();
4128    let user = db
4129        .get_user_by_id(user_id)
4130        .await?
4131        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4132
4133    if user.accepted_tos_at.is_none() {
4134        Err(anyhow!("terms of service not accepted"))?
4135    }
4136
4137    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4138    let billing_preferences = db.get_billing_preferences(user.id).await?;
4139
4140    let token = LlmTokenClaims::create(
4141        &user,
4142        session.is_staff(),
4143        billing_preferences,
4144        &flags,
4145        has_llm_subscription,
4146        session.current_plan(&db).await?,
4147        session.system_id.clone(),
4148        &session.app_state.config,
4149    )?;
4150    response.send(proto::GetLlmTokenResponse { token })?;
4151    Ok(())
4152}
4153
4154fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4155    let message = match message {
4156        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4157        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4158        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4159        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4160        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4161            code: frame.code.into(),
4162            reason: frame.reason.as_str().to_owned().into(),
4163        })),
4164        // We should never receive a frame while reading the message, according
4165        // to the `tungstenite` maintainers:
4166        //
4167        // > It cannot occur when you read messages from the WebSocket, but it
4168        // > can be used when you want to send the raw frames (e.g. you want to
4169        // > send the frames to the WebSocket without composing the full message first).
4170        // >
4171        // > — https://github.com/snapview/tungstenite-rs/issues/268
4172        TungsteniteMessage::Frame(_) => {
4173            bail!("received an unexpected frame while reading the message")
4174        }
4175    };
4176
4177    Ok(message)
4178}
4179
4180fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4181    match message {
4182        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4183        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4184        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4185        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4186        AxumMessage::Close(frame) => {
4187            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4188                code: frame.code.into(),
4189                reason: frame.reason.as_ref().into(),
4190            }))
4191        }
4192    }
4193}
4194
4195fn notify_membership_updated(
4196    connection_pool: &mut ConnectionPool,
4197    result: MembershipUpdated,
4198    user_id: UserId,
4199    peer: &Peer,
4200) {
4201    for membership in &result.new_channels.channel_memberships {
4202        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4203    }
4204    for channel_id in &result.removed_channels {
4205        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4206    }
4207
4208    let user_channels_update = proto::UpdateUserChannels {
4209        channel_memberships: result
4210            .new_channels
4211            .channel_memberships
4212            .iter()
4213            .map(|cm| proto::ChannelMembership {
4214                channel_id: cm.channel_id.to_proto(),
4215                role: cm.role.into(),
4216            })
4217            .collect(),
4218        ..Default::default()
4219    };
4220
4221    let mut update = build_channels_update(result.new_channels);
4222    update.delete_channels = result
4223        .removed_channels
4224        .into_iter()
4225        .map(|id| id.to_proto())
4226        .collect();
4227    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4228
4229    for connection_id in connection_pool.user_connection_ids(user_id) {
4230        peer.send(connection_id, user_channels_update.clone())
4231            .trace_err();
4232        peer.send(connection_id, update.clone()).trace_err();
4233    }
4234}
4235
4236fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4237    proto::UpdateUserChannels {
4238        channel_memberships: channels
4239            .channel_memberships
4240            .iter()
4241            .map(|m| proto::ChannelMembership {
4242                channel_id: m.channel_id.to_proto(),
4243                role: m.role.into(),
4244            })
4245            .collect(),
4246        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4247        observed_channel_message_id: channels.observed_channel_messages.clone(),
4248    }
4249}
4250
4251fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4252    let mut update = proto::UpdateChannels::default();
4253
4254    for channel in channels.channels {
4255        update.channels.push(channel.to_proto());
4256    }
4257
4258    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4259    update.latest_channel_message_ids = channels.latest_channel_messages;
4260
4261    for (channel_id, participants) in channels.channel_participants {
4262        update
4263            .channel_participants
4264            .push(proto::ChannelParticipants {
4265                channel_id: channel_id.to_proto(),
4266                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4267            });
4268    }
4269
4270    for channel in channels.invited_channels {
4271        update.channel_invitations.push(channel.to_proto());
4272    }
4273
4274    update
4275}
4276
4277fn build_initial_contacts_update(
4278    contacts: Vec<db::Contact>,
4279    pool: &ConnectionPool,
4280) -> proto::UpdateContacts {
4281    let mut update = proto::UpdateContacts::default();
4282
4283    for contact in contacts {
4284        match contact {
4285            db::Contact::Accepted { user_id, busy } => {
4286                update.contacts.push(contact_for_user(user_id, busy, pool));
4287            }
4288            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4289            db::Contact::Incoming { user_id } => {
4290                update
4291                    .incoming_requests
4292                    .push(proto::IncomingContactRequest {
4293                        requester_id: user_id.to_proto(),
4294                    })
4295            }
4296        }
4297    }
4298
4299    update
4300}
4301
4302fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4303    proto::Contact {
4304        user_id: user_id.to_proto(),
4305        online: pool.is_user_online(user_id),
4306        busy,
4307    }
4308}
4309
4310fn room_updated(room: &proto::Room, peer: &Peer) {
4311    broadcast(
4312        None,
4313        room.participants
4314            .iter()
4315            .filter_map(|participant| Some(participant.peer_id?.into())),
4316        |peer_id| {
4317            peer.send(
4318                peer_id,
4319                proto::RoomUpdated {
4320                    room: Some(room.clone()),
4321                },
4322            )
4323        },
4324    );
4325}
4326
4327fn channel_updated(
4328    channel: &db::channel::Model,
4329    room: &proto::Room,
4330    peer: &Peer,
4331    pool: &ConnectionPool,
4332) {
4333    let participants = room
4334        .participants
4335        .iter()
4336        .map(|p| p.user_id)
4337        .collect::<Vec<_>>();
4338
4339    broadcast(
4340        None,
4341        pool.channel_connection_ids(channel.root_id())
4342            .filter_map(|(channel_id, role)| {
4343                role.can_see_channel(channel.visibility)
4344                    .then_some(channel_id)
4345            }),
4346        |peer_id| {
4347            peer.send(
4348                peer_id,
4349                proto::UpdateChannels {
4350                    channel_participants: vec![proto::ChannelParticipants {
4351                        channel_id: channel.id.to_proto(),
4352                        participant_user_ids: participants.clone(),
4353                    }],
4354                    ..Default::default()
4355                },
4356            )
4357        },
4358    );
4359}
4360
4361async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4362    let db = session.db().await;
4363
4364    let contacts = db.get_contacts(user_id).await?;
4365    let busy = db.is_user_busy(user_id).await?;
4366
4367    let pool = session.connection_pool().await;
4368    let updated_contact = contact_for_user(user_id, busy, &pool);
4369    for contact in contacts {
4370        if let db::Contact::Accepted {
4371            user_id: contact_user_id,
4372            ..
4373        } = contact
4374        {
4375            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4376                session
4377                    .peer
4378                    .send(
4379                        contact_conn_id,
4380                        proto::UpdateContacts {
4381                            contacts: vec![updated_contact.clone()],
4382                            remove_contacts: Default::default(),
4383                            incoming_requests: Default::default(),
4384                            remove_incoming_requests: Default::default(),
4385                            outgoing_requests: Default::default(),
4386                            remove_outgoing_requests: Default::default(),
4387                        },
4388                    )
4389                    .trace_err();
4390            }
4391        }
4392    }
4393    Ok(())
4394}
4395
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If `leave_room` returns `None`, this returns `Ok(())` without doing anything.
/// Otherwise it performs these steps, in order:
/// 1. calls `project_left` for every project the connection left;
/// 2. broadcasts the updated room via `room_updated`;
/// 3. if the room belonged to a channel, broadcasts its participants via
///    `channel_updated`;
/// 4. sends `CallCanceled` to every connection of each user whose pending call
///    was canceled;
/// 5. refreshes contact status for the leaving user and for those callees;
/// 6. if a LiveKit client is configured, removes the user from the LiveKit
///    room, and deletes the room when `left_room.deleted` was set.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// NOTE(review): an error from `update_user_contacts` returns early and skips
/// the LiveKit cleanup at the end; confirm that is intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // The leaving user plus any callees whose calls get canceled below.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` so the values outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE: this leaves the default value in `left_room.room.livekit_room`,
        // so the `room` taken on the next lines (and broadcast below) carries an
        // empty `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool guard is a temporary that lives for this call only.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit failures are logged via `trace_err` and do not fail the call.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4469
4470async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4471    let left_channel_buffers = session
4472        .db()
4473        .await
4474        .leave_channel_buffers(session.connection_id)
4475        .await?;
4476
4477    for left_buffer in left_channel_buffers {
4478        channel_buffer_updated(
4479            session.connection_id,
4480            left_buffer.connections,
4481            &proto::UpdateChannelBufferCollaborators {
4482                channel_id: left_buffer.channel_id.to_proto(),
4483                collaborators: left_buffer.collaborators,
4484            },
4485            &session.peer,
4486        );
4487    }
4488
4489    Ok(())
4490}
4491
4492fn project_left(project: &db::LeftProject, session: &Session) {
4493    for connection_id in &project.connection_ids {
4494        if project.should_unshare {
4495            session
4496                .peer
4497                .send(
4498                    *connection_id,
4499                    proto::UnshareProject {
4500                        project_id: project.id.to_proto(),
4501                    },
4502                )
4503                .trace_err();
4504        } else {
4505            session
4506                .peer
4507                .send(
4508                    *connection_id,
4509                    proto::RemoveProjectCollaborator {
4510                        project_id: project.id.to_proto(),
4511                        peer_id: Some(session.connection_id.into()),
4512                    },
4513                )
4514                .trace_err();
4515        }
4516    }
4517}
4518
/// Extension for fallible values whose failure should be reported rather than
/// propagated.
pub trait ResultExt {
    /// Consumes `self`, yielding the success value if there is one and `None`
    /// otherwise. How the failure is reported is up to the implementation.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The type produced on success.
    type Ok;
}
4524
4525impl<T, E> ResultExt for Result<T, E>
4526where
4527    E: std::fmt::Debug,
4528{
4529    type Ok = T;
4530
4531    #[track_caller]
4532    fn trace_err(self) -> Option<T> {
4533        match self {
4534            Ok(value) => Some(value),
4535            Err(error) => {
4536                tracing::error!("{:?}", error);
4537                None
4538            }
4539        }
4540    }
4541}