rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use rpc::proto::split_repository_update;
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  46    TryStreamExt,
  47};
  48use prometheus::{register_int_gauge, IntGauge};
  49use rpc::{
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68        Arc, OnceLock,
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{watch, MutexGuard, Semaphore};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    field::{self},
  77    info_span, instrument, Instrument,
  78};
  79
/// Reconnection timeout: 30 seconds.
// NOTE(review): the use site is outside this view; presumably the window a dropped client
// has to reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it begins clearing resources of stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): none of the three limits below is used in the visible part of this file.
// Judging by their names they are page sizes and a length cap applied by the channel
// message / notification handlers — confirm the unit of `MAX_MESSAGE_LEN` (bytes vs. chars)
// at its use site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the incoming envelope as a boxed `AnyTypedEnvelope` together with the
/// `Session` it arrived on, and returns a boxed `'static` future that yields `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
/// Reply handle given to a request handler for a request of type `R`.
struct Response<R> {
    /// Peer through which the reply is sent.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Flag set to `true` by `send`; `Server::add_request_handler` reads it to detect
    /// handlers that returned `Ok(())` without replying.
    responded: Arc<AtomicBool>,
}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`; session-level lookups such as `Session::user_id`
    /// resolve to `user`.
    Impersonated { user: User, admin: User },
}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
/// Per-connection state handed to every message handler.
///
/// The struct is `Clone`; the shared members are reference-counted handles.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous mutex; acquire it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client for outbound requests.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System identifier supplied when the connection was established, if any.
    system_id: Option<String>,
    /// Executor associated with the connection; not read through this field (hence the
    /// leading underscore).
    _executor: Executor,
}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
/// Newtype around the shared `Database`; dereferences to `Database`.
struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to exchange messages with connections.
    peer: Arc<Peer>,
    /// Connection pool, shared with sessions and background tasks.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel that carries `true` after `teardown` and `false` after `reset`.
    teardown: watch::Sender<bool>,
}
 245
/// Wrapper around the connection pool's mutex guard.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The held lock on the connection pool.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is not `Send`, so this marker makes the wrapper itself not `Send`.
    // NOTE(review): presumably intended to stop the guard from being held across an
    // `.await` in a future that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
 250
/// Serializable view of the server: its peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Guard over the connection pool; serialized through its `Deref` target by
    /// `serialize_deref` rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
    /// Builds a server with the given id and application state and registers every RPC
    /// handler on it.
    ///
    /// Request handlers (`add_request_handler`) receive a `Response` handle and are expected
    /// to reply; message handlers (`add_message_handler`) are fire-and-forget. Registering
    /// two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sending half of the watch channel is kept; the initial value is
            // `false` (not torn down).
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            // NOTE(review): `GetCodeLens` is registered as mutating although the name reads
            // like a query; the forwarding helpers are defined outside this view — confirm
            // this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` sound like operations that
            // change the host's repository, yet they are forwarded through the read-only
            // path while the neighbouring git requests use the mutating one. The helpers are
            // defined outside this view — confirm whether this is intentional.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The two closures below capture a clone of `app_state` so the handlers can read
            // the server configuration.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    // Clone per call so the returned future owns its own handle.
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 445
    /// Spawns a detached background task that cleans up resources left behind by stale
    /// servers, then returns `Ok(())` immediately.
    ///
    /// The task waits for `CLEANUP_TIMEOUT`, asks the database for stale room and channel
    /// ids, clears stale channel-buffer collaborators and room participants, notifies the
    /// affected connections, deletes LiveKit rooms that ended up empty, and finally deletes
    /// the stale server records. Failures inside the task go through `trace_err()` and are
    /// not propagated to the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then fan out the consequences.
                    for room_id in room_ids {
                        // These keep their defaults when clearing the room fails, which
                        // turns the steps below into no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool guard is a temporary that is released at the end
                                // of this statement.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move the data needed after this block out of `refreshed_room`.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that the call was canceled.
                        // The block scope releases the pool lock before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of each affected user to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Proceed only if both lookups succeeded; the pool lock is taken
                            // after the awaits above have completed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room has no
                        // participants left and a LiveKit client is configured.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 591
    /// Tears down the peer, resets the connection pool, and publishes `true` on the
    /// teardown watch channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // The send result is deliberately ignored.
        let _ = self.teardown.send(true);
    }
 597
    /// Test-only: tears the server down, then installs `id` as the new server id, resets
    /// the peer with it, and publishes `false` on the teardown watch channel again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // The send result is deliberately ignored.
        let _ = self.teardown.send(false);
    }
 605
    /// Test-only: returns a copy of the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 610
 611    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 612    where
 613        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 614        Fut: 'static + Send + Future<Output = Result<()>>,
 615        M: EnvelopedMessage,
 616    {
 617        let prev_handler = self.handlers.insert(
 618            TypeId::of::<M>(),
 619            Box::new(move |envelope, session| {
 620                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 621                let received_at = envelope.received_at;
 622                tracing::info!("message received");
 623                let start_time = Instant::now();
 624                let future = (handler)(*envelope, session);
 625                async move {
 626                    let result = future.await;
 627                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 628                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 629                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 630                    let payload_type = M::NAME;
 631
 632                    match result {
 633                        Err(error) => {
 634                            tracing::error!(
 635                                ?error,
 636                                total_duration_ms,
 637                                processing_duration_ms,
 638                                queue_duration_ms,
 639                                payload_type,
 640                                "error handling message"
 641                            )
 642                        }
 643                        Ok(()) => tracing::info!(
 644                            total_duration_ms,
 645                            processing_duration_ms,
 646                            queue_duration_ms,
 647                            "finished handling message"
 648                        ),
 649                    }
 650                }
 651                .boxed()
 652            }),
 653        );
 654        if prev_handler.is_some() {
 655            panic!("registered a handler for the same message twice");
 656        }
 657        self
 658    }
 659
 660    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 661    where
 662        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 663        Fut: 'static + Send + Future<Output = Result<()>>,
 664        M: EnvelopedMessage,
 665    {
 666        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 667        self
 668    }
 669
 670    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 671    where
 672        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 673        Fut: Send + Future<Output = Result<()>>,
 674        M: RequestMessage,
 675    {
 676        let handler = Arc::new(handler);
 677        self.add_handler(move |envelope, session| {
 678            let receipt = envelope.receipt();
 679            let handler = handler.clone();
 680            async move {
 681                let peer = session.peer.clone();
 682                let responded = Arc::new(AtomicBool::default());
 683                let response = Response {
 684                    peer: peer.clone(),
 685                    responded: responded.clone(),
 686                    receipt,
 687                };
 688                match (handler)(envelope.payload, response, session).await {
 689                    Ok(()) => {
 690                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 691                            Ok(())
 692                        } else {
 693                            Err(anyhow!("handler did not send a response"))?
 694                        }
 695                    }
 696                    Err(error) => {
 697                        let proto_err = match &error {
 698                            Error::Internal(err) => err.to_proto(),
 699                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 700                        };
 701                        peer.respond_with_error(receipt, proto_err)?;
 702                        Err(error)
 703                    }
 704                }
 705            }
 706        })
 707    }
 708
    /// Returns a future that drives a single client connection from setup to sign-out.
    ///
    /// The future:
    /// 1. bails out immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer, builds a `Session`, and sends the initial
    ///    client update (see `send_initial_client_update`);
    /// 3. loops, dispatching each incoming message to the handler registered for its payload
    ///    type, until the I/O future finishes, the incoming stream ends, or teardown fires;
    /// 4. unless teardown fired inside the loop, calls `connection_lost` to sign out.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`.
    /// The whole future is instrumented with a "handle connection" span that records the
    /// address, the principal's fields and (when present) the GeoIP country code.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `field::Empty` are filled in later (by `update_span`/`record`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each in-flight handler holds one permit, so at most 256 messages from this
            // connection are being handled at any time.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so no further
                // message is pulled off the stream while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written below.
                futures::select_biased! {
                    // Teardown returns from the whole future, skipping the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive the pending foreground handlers; their output is `()`.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 857
 858    async fn send_initial_client_update(
 859        &self,
 860        connection_id: ConnectionId,
 861        principal: &Principal,
 862        zed_version: ZedVersion,
 863        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 864        session: &Session,
 865    ) -> Result<()> {
 866        self.peer.send(
 867            connection_id,
 868            proto::Hello {
 869                peer_id: Some(connection_id.into()),
 870            },
 871        )?;
 872        tracing::info!("sent hello message");
 873        if let Some(send_connection_id) = send_connection_id.take() {
 874            let _ = send_connection_id.send(connection_id);
 875        }
 876
 877        match principal {
 878            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 879                if !user.connected_once {
 880                    self.peer.send(connection_id, proto::ShowContacts {})?;
 881                    self.app_state
 882                        .db
 883                        .set_user_connected_once(user.id, true)
 884                        .await?;
 885                }
 886
 887                update_user_plan(user.id, session).await?;
 888
 889                let contacts = self.app_state.db.get_contacts(user.id).await?;
 890
 891                {
 892                    let mut pool = self.connection_pool.lock();
 893                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 894                    self.peer.send(
 895                        connection_id,
 896                        build_initial_contacts_update(contacts, &pool),
 897                    )?;
 898                }
 899
 900                if should_auto_subscribe_to_channels(zed_version) {
 901                    subscribe_user_to_channels(user.id, session).await?;
 902                }
 903
 904                if let Some(incoming_call) =
 905                    self.app_state.db.incoming_call_for_user(user.id).await?
 906                {
 907                    self.peer.send(connection_id, incoming_call)?;
 908                }
 909
 910                update_user_contacts(user.id, session).await?;
 911            }
 912        }
 913
 914        Ok(())
 915    }
 916
 917    pub async fn invite_code_redeemed(
 918        self: &Arc<Self>,
 919        inviter_id: UserId,
 920        invitee_id: UserId,
 921    ) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 923            if let Some(code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 926                for connection_id in pool.user_connection_ids(inviter_id) {
 927                    self.peer.send(
 928                        connection_id,
 929                        proto::UpdateContacts {
 930                            contacts: vec![invitee_contact.clone()],
 931                            ..Default::default()
 932                        },
 933                    )?;
 934                    self.peer.send(
 935                        connection_id,
 936                        proto::UpdateInviteInfo {
 937                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 938                            count: user.invite_count as u32,
 939                        },
 940                    )?;
 941                }
 942            }
 943        }
 944        Ok(())
 945    }
 946
 947    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 948        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 949            if let Some(invite_code) = &user.invite_code {
 950                let pool = self.connection_pool.lock();
 951                for connection_id in pool.user_connection_ids(user_id) {
 952                    self.peer.send(
 953                        connection_id,
 954                        proto::UpdateInviteInfo {
 955                            url: format!(
 956                                "{}{}",
 957                                self.app_state.config.invite_link_prefix, invite_code
 958                            ),
 959                            count: user.invite_count as u32,
 960                        },
 961                    )?;
 962                }
 963            }
 964        }
 965        Ok(())
 966    }
 967
 968    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 969        let pool = self.connection_pool.lock();
 970        for connection_id in pool.user_connection_ids(user_id) {
 971            self.peer
 972                .send(connection_id, proto::RefreshLlmToken {})
 973                .trace_err();
 974        }
 975    }
 976
 977    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 978        ServerSnapshot {
 979            connection_pool: ConnectionPoolGuard {
 980                guard: self.connection_pool.lock(),
 981                _not_send: PhantomData,
 982            },
 983            peer: &self.peer,
 984        }
 985    }
 986}
 987
 988impl Deref for ConnectionPoolGuard<'_> {
 989    type Target = ConnectionPool;
 990
 991    fn deref(&self) -> &Self::Target {
 992        &self.guard
 993    }
 994}
 995
 996impl DerefMut for ConnectionPoolGuard<'_> {
 997    fn deref_mut(&mut self) -> &mut Self::Target {
 998        &mut self.guard
 999    }
1000}
1001
/// In test builds, runs `check_invariants` every time a guard is dropped.
/// In non-test builds the `drop` body is empty.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The attribute applies to the single statement that follows it.
        #[cfg(test)]
        self.check_invariants();
    }
}
1008
1009fn broadcast<F>(
1010    sender_id: Option<ConnectionId>,
1011    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1012    mut f: F,
1013) where
1014    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1015{
1016    for receiver_id in receiver_ids {
1017        if Some(receiver_id) != sender_id {
1018            if let Err(error) = f(receiver_id) {
1019                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1020            }
1021        }
1022    }
1023}
1024
1025pub struct ProtocolVersion(u32);
1026
1027impl Header for ProtocolVersion {
1028    fn name() -> &'static HeaderName {
1029        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1030        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1031    }
1032
1033    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1034    where
1035        Self: Sized,
1036        I: Iterator<Item = &'i axum::http::HeaderValue>,
1037    {
1038        let version = values
1039            .next()
1040            .ok_or_else(axum::headers::Error::invalid)?
1041            .to_str()
1042            .map_err(|_| axum::headers::Error::invalid())?
1043            .parse()
1044            .map_err(|_| axum::headers::Error::invalid())?;
1045        Ok(Self(version))
1046    }
1047
1048    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1049        values.extend([self.0.to_string().parse().unwrap()]);
1050    }
1051}
1052
1053pub struct AppVersionHeader(SemanticVersion);
1054impl Header for AppVersionHeader {
1055    fn name() -> &'static HeaderName {
1056        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1057        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1058    }
1059
1060    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1061    where
1062        Self: Sized,
1063        I: Iterator<Item = &'i axum::http::HeaderValue>,
1064    {
1065        let version = values
1066            .next()
1067            .ok_or_else(axum::headers::Error::invalid)?
1068            .to_str()
1069            .map_err(|_| axum::headers::Error::invalid())?
1070            .parse()
1071            .map_err(|_| axum::headers::Error::invalid())?;
1072        Ok(Self(version))
1073    }
1074
1075    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1076        values.extend([self.0.to_string().parse().unwrap()]);
1077    }
1078}
1079
1080pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1081    Router::new()
1082        .route("/rpc", get(handle_websocket_request))
1083        .layer(
1084            ServiceBuilder::new()
1085                .layer(Extension(server.app_state.clone()))
1086                .layer(middleware::from_fn(auth::validate_header)),
1087        )
1088        .route("/metrics", get(handle_metrics))
1089        .layer(Extension(server))
1090}
1091
1092pub async fn handle_websocket_request(
1093    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1094    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1095    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1096    Extension(server): Extension<Arc<Server>>,
1097    Extension(principal): Extension<Principal>,
1098    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1099    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1100    ws: WebSocketUpgrade,
1101) -> axum::response::Response {
1102    if protocol_version != rpc::PROTOCOL_VERSION {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "client must be upgraded".to_string(),
1106        )
1107            .into_response();
1108    }
1109
1110    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "no version header found".to_string(),
1114        )
1115            .into_response();
1116    };
1117
1118    if !version.can_collaborate() {
1119        return (
1120            StatusCode::UPGRADE_REQUIRED,
1121            "client must be upgraded".to_string(),
1122        )
1123            .into_response();
1124    }
1125
1126    let socket_address = socket_address.to_string();
1127    ws.on_upgrade(move |socket| {
1128        let socket = socket
1129            .map_ok(to_tungstenite_message)
1130            .err_into()
1131            .with(|message| async move { to_axum_message(message) });
1132        let connection = Connection::new(Box::pin(socket));
1133        async move {
1134            server
1135                .handle_connection(
1136                    connection,
1137                    socket_address,
1138                    principal,
1139                    version,
1140                    country_code_header.map(|header| header.to_string()),
1141                    system_id_header.map(|header| header.to_string()),
1142                    None,
1143                    Executor::Production,
1144                )
1145                .await;
1146        }
1147    })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152    let connections_metric = CONNECTIONS_METRIC
1153        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155    let connections = server
1156        .connection_pool
1157        .lock()
1158        .connections()
1159        .filter(|connection| !connection.admin)
1160        .count();
1161    connections_metric.set(connections as _);
1162
1163    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165        register_int_gauge!(
1166            "shared_projects",
1167            "number of open projects with one or more guests"
1168        )
1169        .unwrap()
1170    });
1171
1172    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173    shared_projects_metric.set(shared_projects as _);
1174
1175    let encoder = prometheus::TextEncoder::new();
1176    let metric_families = prometheus::gather();
1177    let encoded_metrics = encoder
1178        .encode_to_string(&metric_families)
1179        .map_err(|err| anyhow!("{}", err))?;
1180    Ok(encoded_metrics)
1181}
1182
/// Signs a session out after its connection has gone away.
///
/// Immediately: disconnects the peer connection, removes the connection from the
/// connection pool, and records the loss in the database (a database error there is only
/// traced).
///
/// Then waits for whichever comes first:
/// * `RECONNECT_TIMEOUT` elapsing — the session's room and channel buffers are left, a
///   pending call is declined if the user has no remaining online connection, and the
///   user's contacts are updated;
/// * a change on `teardown` — nothing further is cleaned up.
///
/// Returns an error if removing the connection from the pool fails or if the final
/// contacts update fails; the other cleanup errors are traced and swallowed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the timeout arm is polled before the teardown arm.
    futures::select_biased! {
        // NOTE(review): the delay presumably gives the client a window to reconnect before
        // its resources are released — confirm against the reconnect handling.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when the user has no other online connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1229
1230/// Acknowledges a ping from a client, used to keep the connection alive.
1231async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1232    response.send(proto::Ack {})?;
1233    Ok(())
1234}
1235
1236/// Creates a new room for calling (outside of channels)
1237async fn create_room(
1238    _request: proto::CreateRoom,
1239    response: Response<proto::CreateRoom>,
1240    session: Session,
1241) -> Result<()> {
1242    let livekit_room = nanoid::nanoid!(30);
1243
1244    let live_kit_connection_info = util::maybe!(async {
1245        let live_kit = session.app_state.livekit_client.as_ref();
1246        let live_kit = live_kit?;
1247        let user_id = session.user_id().to_string();
1248
1249        let token = live_kit
1250            .room_token(&livekit_room, &user_id.to_string())
1251            .trace_err()?;
1252
1253        Some(proto::LiveKitConnectionInfo {
1254            server_url: live_kit.url().into(),
1255            token,
1256            can_publish: true,
1257        })
1258    })
1259    .await;
1260
1261    let room = session
1262        .db()
1263        .await
1264        .create_room(session.user_id(), session.connection_id, &livekit_room)
1265        .await?;
1266
1267    response.send(proto::CreateRoomResponse {
1268        room: Some(room.clone()),
1269        live_kit_connection_info,
1270    })?;
1271
1272    update_user_contacts(session.user_id(), &session).await?;
1273    Ok(())
1274}
1275
1276/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1277async fn join_room(
1278    request: proto::JoinRoom,
1279    response: Response<proto::JoinRoom>,
1280    session: Session,
1281) -> Result<()> {
1282    let room_id = RoomId::from_proto(request.id);
1283
1284    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1285
1286    if let Some(channel_id) = channel_id {
1287        return join_channel_internal(channel_id, Box::new(response), session).await;
1288    }
1289
1290    let joined_room = {
1291        let room = session
1292            .db()
1293            .await
1294            .join_room(room_id, session.user_id(), session.connection_id)
1295            .await?;
1296        room_updated(&room.room, &session.peer);
1297        room.into_inner()
1298    };
1299
1300    for connection_id in session
1301        .connection_pool()
1302        .await
1303        .user_connection_ids(session.user_id())
1304    {
1305        session
1306            .peer
1307            .send(
1308                connection_id,
1309                proto::CallCanceled {
1310                    room_id: room_id.to_proto(),
1311                },
1312            )
1313            .trace_err();
1314    }
1315
1316    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1317    {
1318        live_kit
1319            .room_token(
1320                &joined_room.room.livekit_room,
1321                &session.user_id().to_string(),
1322            )
1323            .trace_err()
1324            .map(|token| proto::LiveKitConnectionInfo {
1325                server_url: live_kit.url().into(),
1326                token,
1327                can_publish: true,
1328            })
1329    } else {
1330        None
1331    };
1332
1333    response.send(proto::JoinRoomResponse {
1334        room: Some(joined_room.room),
1335        channel_id: None,
1336        live_kit_connection_info,
1337    })?;
1338
1339    update_user_contacts(session.user_id(), &session).await?;
1340    Ok(())
1341}
1342
1343/// Rejoin room is used to reconnect to a room after connection errors.
1344async fn rejoin_room(
1345    request: proto::RejoinRoom,
1346    response: Response<proto::RejoinRoom>,
1347    session: Session,
1348) -> Result<()> {
1349    let room;
1350    let channel;
1351    {
1352        let mut rejoined_room = session
1353            .db()
1354            .await
1355            .rejoin_room(request, session.user_id(), session.connection_id)
1356            .await?;
1357
1358        response.send(proto::RejoinRoomResponse {
1359            room: Some(rejoined_room.room.clone()),
1360            reshared_projects: rejoined_room
1361                .reshared_projects
1362                .iter()
1363                .map(|project| proto::ResharedProject {
1364                    id: project.id.to_proto(),
1365                    collaborators: project
1366                        .collaborators
1367                        .iter()
1368                        .map(|collaborator| collaborator.to_proto())
1369                        .collect(),
1370                })
1371                .collect(),
1372            rejoined_projects: rejoined_room
1373                .rejoined_projects
1374                .iter()
1375                .map(|rejoined_project| rejoined_project.to_proto())
1376                .collect(),
1377        })?;
1378        room_updated(&rejoined_room.room, &session.peer);
1379
1380        for project in &rejoined_room.reshared_projects {
1381            for collaborator in &project.collaborators {
1382                session
1383                    .peer
1384                    .send(
1385                        collaborator.connection_id,
1386                        proto::UpdateProjectCollaborator {
1387                            project_id: project.id.to_proto(),
1388                            old_peer_id: Some(project.old_connection_id.into()),
1389                            new_peer_id: Some(session.connection_id.into()),
1390                        },
1391                    )
1392                    .trace_err();
1393            }
1394
1395            broadcast(
1396                Some(session.connection_id),
1397                project
1398                    .collaborators
1399                    .iter()
1400                    .map(|collaborator| collaborator.connection_id),
1401                |connection_id| {
1402                    session.peer.forward_send(
1403                        session.connection_id,
1404                        connection_id,
1405                        proto::UpdateProject {
1406                            project_id: project.id.to_proto(),
1407                            worktrees: project.worktrees.clone(),
1408                        },
1409                    )
1410                },
1411            );
1412        }
1413
1414        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1415
1416        let rejoined_room = rejoined_room.into_inner();
1417
1418        room = rejoined_room.room;
1419        channel = rejoined_room.channel;
1420    }
1421
1422    if let Some(channel) = channel {
1423        channel_updated(
1424            &channel,
1425            &room,
1426            &session.peer,
1427            &*session.connection_pool().await,
1428        );
1429    }
1430
1431    update_user_contacts(session.user_id(), &session).await?;
1432    Ok(())
1433}
1434
1435fn notify_rejoined_projects(
1436    rejoined_projects: &mut Vec<RejoinedProject>,
1437    session: &Session,
1438) -> Result<()> {
1439    for project in rejoined_projects.iter() {
1440        for collaborator in &project.collaborators {
1441            session
1442                .peer
1443                .send(
1444                    collaborator.connection_id,
1445                    proto::UpdateProjectCollaborator {
1446                        project_id: project.id.to_proto(),
1447                        old_peer_id: Some(project.old_connection_id.into()),
1448                        new_peer_id: Some(session.connection_id.into()),
1449                    },
1450                )
1451                .trace_err();
1452        }
1453    }
1454
1455    for project in rejoined_projects {
1456        for worktree in mem::take(&mut project.worktrees) {
1457            // Stream this worktree's entries.
1458            let message = proto::UpdateWorktree {
1459                project_id: project.id.to_proto(),
1460                worktree_id: worktree.id,
1461                abs_path: worktree.abs_path.clone(),
1462                root_name: worktree.root_name,
1463                updated_entries: worktree.updated_entries,
1464                removed_entries: worktree.removed_entries,
1465                scan_id: worktree.scan_id,
1466                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1467                updated_repositories: worktree.updated_repositories,
1468                removed_repositories: worktree.removed_repositories,
1469            };
1470            for update in proto::split_worktree_update(message) {
1471                session.peer.send(session.connection_id, update)?;
1472            }
1473
1474            // Stream this worktree's diagnostics.
1475            for summary in worktree.diagnostic_summaries {
1476                session.peer.send(
1477                    session.connection_id,
1478                    proto::UpdateDiagnosticSummary {
1479                        project_id: project.id.to_proto(),
1480                        worktree_id: worktree.id,
1481                        summary: Some(summary),
1482                    },
1483                )?;
1484            }
1485
1486            for settings_file in worktree.settings_files {
1487                session.peer.send(
1488                    session.connection_id,
1489                    proto::UpdateWorktreeSettings {
1490                        project_id: project.id.to_proto(),
1491                        worktree_id: worktree.id,
1492                        path: settings_file.path,
1493                        content: Some(settings_file.content),
1494                        kind: Some(settings_file.kind.to_proto().into()),
1495                    },
1496                )?;
1497            }
1498        }
1499
1500        for repository in mem::take(&mut project.updated_repositories) {
1501            for update in split_repository_update(repository) {
1502                session.peer.send(session.connection_id, update)?;
1503            }
1504        }
1505
1506        for id in mem::take(&mut project.removed_repositories) {
1507            session.peer.send(
1508                session.connection_id,
1509                proto::RemoveRepository {
1510                    project_id: project.id.to_proto(),
1511                    id,
1512                },
1513            )?;
1514        }
1515    }
1516
1517    Ok(())
1518}
1519
1520/// leave room disconnects from the room.
1521async fn leave_room(
1522    _: proto::LeaveRoom,
1523    response: Response<proto::LeaveRoom>,
1524    session: Session,
1525) -> Result<()> {
1526    leave_room_for_session(&session, session.connection_id).await?;
1527    response.send(proto::Ack {})?;
1528    Ok(())
1529}
1530
1531/// Updates the permissions of someone else in the room.
1532async fn set_room_participant_role(
1533    request: proto::SetRoomParticipantRole,
1534    response: Response<proto::SetRoomParticipantRole>,
1535    session: Session,
1536) -> Result<()> {
1537    let user_id = UserId::from_proto(request.user_id);
1538    let role = ChannelRole::from(request.role());
1539
1540    let (livekit_room, can_publish) = {
1541        let room = session
1542            .db()
1543            .await
1544            .set_room_participant_role(
1545                session.user_id(),
1546                RoomId::from_proto(request.room_id),
1547                user_id,
1548                role,
1549            )
1550            .await?;
1551
1552        let livekit_room = room.livekit_room.clone();
1553        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1554        room_updated(&room, &session.peer);
1555        (livekit_room, can_publish)
1556    };
1557
1558    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1559        live_kit
1560            .update_participant(
1561                livekit_room.clone(),
1562                request.user_id.to_string(),
1563                livekit_api::proto::ParticipantPermission {
1564                    can_subscribe: true,
1565                    can_publish,
1566                    can_publish_data: can_publish,
1567                    hidden: false,
1568                    recorder: false,
1569                },
1570            )
1571            .await
1572            .trace_err();
1573    }
1574
1575    response.send(proto::Ack {})?;
1576    Ok(())
1577}
1578
1579/// Call someone else into the current room
1580async fn call(
1581    request: proto::Call,
1582    response: Response<proto::Call>,
1583    session: Session,
1584) -> Result<()> {
1585    let room_id = RoomId::from_proto(request.room_id);
1586    let calling_user_id = session.user_id();
1587    let calling_connection_id = session.connection_id;
1588    let called_user_id = UserId::from_proto(request.called_user_id);
1589    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1590    if !session
1591        .db()
1592        .await
1593        .has_contact(calling_user_id, called_user_id)
1594        .await?
1595    {
1596        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1597    }
1598
1599    let incoming_call = {
1600        let (room, incoming_call) = &mut *session
1601            .db()
1602            .await
1603            .call(
1604                room_id,
1605                calling_user_id,
1606                calling_connection_id,
1607                called_user_id,
1608                initial_project_id,
1609            )
1610            .await?;
1611        room_updated(room, &session.peer);
1612        mem::take(incoming_call)
1613    };
1614    update_user_contacts(called_user_id, &session).await?;
1615
1616    let mut calls = session
1617        .connection_pool()
1618        .await
1619        .user_connection_ids(called_user_id)
1620        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1621        .collect::<FuturesUnordered<_>>();
1622
1623    while let Some(call_response) = calls.next().await {
1624        match call_response.as_ref() {
1625            Ok(_) => {
1626                response.send(proto::Ack {})?;
1627                return Ok(());
1628            }
1629            Err(_) => {
1630                call_response.trace_err();
1631            }
1632        }
1633    }
1634
1635    {
1636        let room = session
1637            .db()
1638            .await
1639            .call_failed(room_id, called_user_id)
1640            .await?;
1641        room_updated(&room, &session.peer);
1642    }
1643    update_user_contacts(called_user_id, &session).await?;
1644
1645    Err(anyhow!("failed to ring user"))?
1646}
1647
1648/// Cancel an outgoing call.
1649async fn cancel_call(
1650    request: proto::CancelCall,
1651    response: Response<proto::CancelCall>,
1652    session: Session,
1653) -> Result<()> {
1654    let called_user_id = UserId::from_proto(request.called_user_id);
1655    let room_id = RoomId::from_proto(request.room_id);
1656    {
1657        let room = session
1658            .db()
1659            .await
1660            .cancel_call(room_id, session.connection_id, called_user_id)
1661            .await?;
1662        room_updated(&room, &session.peer);
1663    }
1664
1665    for connection_id in session
1666        .connection_pool()
1667        .await
1668        .user_connection_ids(called_user_id)
1669    {
1670        session
1671            .peer
1672            .send(
1673                connection_id,
1674                proto::CallCanceled {
1675                    room_id: room_id.to_proto(),
1676                },
1677            )
1678            .trace_err();
1679    }
1680    response.send(proto::Ack {})?;
1681
1682    update_user_contacts(called_user_id, &session).await?;
1683    Ok(())
1684}
1685
1686/// Decline an incoming call.
1687async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1688    let room_id = RoomId::from_proto(message.room_id);
1689    {
1690        let room = session
1691            .db()
1692            .await
1693            .decline_call(Some(room_id), session.user_id())
1694            .await?
1695            .ok_or_else(|| anyhow!("failed to decline call"))?;
1696        room_updated(&room, &session.peer);
1697    }
1698
1699    for connection_id in session
1700        .connection_pool()
1701        .await
1702        .user_connection_ids(session.user_id())
1703    {
1704        session
1705            .peer
1706            .send(
1707                connection_id,
1708                proto::CallCanceled {
1709                    room_id: room_id.to_proto(),
1710                },
1711            )
1712            .trace_err();
1713    }
1714    update_user_contacts(session.user_id(), &session).await?;
1715    Ok(())
1716}
1717
1718/// Updates other participants in the room with your current location.
1719async fn update_participant_location(
1720    request: proto::UpdateParticipantLocation,
1721    response: Response<proto::UpdateParticipantLocation>,
1722    session: Session,
1723) -> Result<()> {
1724    let room_id = RoomId::from_proto(request.room_id);
1725    let location = request
1726        .location
1727        .ok_or_else(|| anyhow!("invalid location"))?;
1728
1729    let db = session.db().await;
1730    let room = db
1731        .update_room_participant_location(room_id, session.connection_id, location)
1732        .await?;
1733
1734    room_updated(&room, &session.peer);
1735    response.send(proto::Ack {})?;
1736    Ok(())
1737}
1738
1739/// Share a project into the room.
1740async fn share_project(
1741    request: proto::ShareProject,
1742    response: Response<proto::ShareProject>,
1743    session: Session,
1744) -> Result<()> {
1745    let (project_id, room) = &*session
1746        .db()
1747        .await
1748        .share_project(
1749            RoomId::from_proto(request.room_id),
1750            session.connection_id,
1751            &request.worktrees,
1752            request.is_ssh_project,
1753        )
1754        .await?;
1755    response.send(proto::ShareProjectResponse {
1756        project_id: project_id.to_proto(),
1757    })?;
1758    room_updated(room, &session.peer);
1759
1760    Ok(())
1761}
1762
1763/// Unshare a project from the room.
1764async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1765    let project_id = ProjectId::from_proto(message.project_id);
1766    unshare_project_internal(project_id, session.connection_id, &session).await
1767}
1768
1769async fn unshare_project_internal(
1770    project_id: ProjectId,
1771    connection_id: ConnectionId,
1772    session: &Session,
1773) -> Result<()> {
1774    let delete = {
1775        let room_guard = session
1776            .db()
1777            .await
1778            .unshare_project(project_id, connection_id)
1779            .await?;
1780
1781        let (delete, room, guest_connection_ids) = &*room_guard;
1782
1783        let message = proto::UnshareProject {
1784            project_id: project_id.to_proto(),
1785        };
1786
1787        broadcast(
1788            Some(connection_id),
1789            guest_connection_ids.iter().copied(),
1790            |conn_id| session.peer.send(conn_id, message.clone()),
1791        );
1792        if let Some(room) = room {
1793            room_updated(room, &session.peer);
1794        }
1795
1796        *delete
1797    };
1798
1799    if delete {
1800        let db = session.db().await;
1801        db.delete_project(project_id).await?;
1802    }
1803
1804    Ok(())
1805}
1806
1807/// Join someone elses shared project.
1808async fn join_project(
1809    request: proto::JoinProject,
1810    response: Response<proto::JoinProject>,
1811    session: Session,
1812) -> Result<()> {
1813    let project_id = ProjectId::from_proto(request.project_id);
1814
1815    tracing::info!(%project_id, "join project");
1816
1817    let db = session.db().await;
1818    let (project, replica_id) = &mut *db
1819        .join_project(project_id, session.connection_id, session.user_id())
1820        .await?;
1821    drop(db);
1822    tracing::info!(%project_id, "join remote project");
1823    join_project_internal(response, session, project, replica_id)
1824}
1825
1826trait JoinProjectInternalResponse {
1827    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1828}
1829impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1830    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1831        Response::<proto::JoinProject>::send(self, result)
1832    }
1833}
1834
1835fn join_project_internal(
1836    response: impl JoinProjectInternalResponse,
1837    session: Session,
1838    project: &mut Project,
1839    replica_id: &ReplicaId,
1840) -> Result<()> {
1841    let collaborators = project
1842        .collaborators
1843        .iter()
1844        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1845        .map(|collaborator| collaborator.to_proto())
1846        .collect::<Vec<_>>();
1847    let project_id = project.id;
1848    let guest_user_id = session.user_id();
1849
1850    let worktrees = project
1851        .worktrees
1852        .iter()
1853        .map(|(id, worktree)| proto::WorktreeMetadata {
1854            id: *id,
1855            root_name: worktree.root_name.clone(),
1856            visible: worktree.visible,
1857            abs_path: worktree.abs_path.clone(),
1858        })
1859        .collect::<Vec<_>>();
1860
1861    let add_project_collaborator = proto::AddProjectCollaborator {
1862        project_id: project_id.to_proto(),
1863        collaborator: Some(proto::Collaborator {
1864            peer_id: Some(session.connection_id.into()),
1865            replica_id: replica_id.0 as u32,
1866            user_id: guest_user_id.to_proto(),
1867            is_host: false,
1868        }),
1869    };
1870
1871    for collaborator in &collaborators {
1872        session
1873            .peer
1874            .send(
1875                collaborator.peer_id.unwrap().into(),
1876                add_project_collaborator.clone(),
1877            )
1878            .trace_err();
1879    }
1880
1881    // First, we send the metadata associated with each worktree.
1882    response.send(proto::JoinProjectResponse {
1883        project_id: project.id.0 as u64,
1884        worktrees: worktrees.clone(),
1885        replica_id: replica_id.0 as u32,
1886        collaborators: collaborators.clone(),
1887        language_servers: project.language_servers.clone(),
1888        role: project.role.into(),
1889    })?;
1890
1891    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1892        // Stream this worktree's entries.
1893        let message = proto::UpdateWorktree {
1894            project_id: project_id.to_proto(),
1895            worktree_id,
1896            abs_path: worktree.abs_path.clone(),
1897            root_name: worktree.root_name,
1898            updated_entries: worktree.entries,
1899            removed_entries: Default::default(),
1900            scan_id: worktree.scan_id,
1901            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1902            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1903            removed_repositories: Default::default(),
1904        };
1905        for update in proto::split_worktree_update(message) {
1906            session.peer.send(session.connection_id, update.clone())?;
1907        }
1908
1909        // Stream this worktree's diagnostics.
1910        for summary in worktree.diagnostic_summaries {
1911            session.peer.send(
1912                session.connection_id,
1913                proto::UpdateDiagnosticSummary {
1914                    project_id: project_id.to_proto(),
1915                    worktree_id: worktree.id,
1916                    summary: Some(summary),
1917                },
1918            )?;
1919        }
1920
1921        for settings_file in worktree.settings_files {
1922            session.peer.send(
1923                session.connection_id,
1924                proto::UpdateWorktreeSettings {
1925                    project_id: project_id.to_proto(),
1926                    worktree_id: worktree.id,
1927                    path: settings_file.path,
1928                    content: Some(settings_file.content),
1929                    kind: Some(settings_file.kind.to_proto() as i32),
1930                },
1931            )?;
1932        }
1933    }
1934
1935    for repository in mem::take(&mut project.repositories) {
1936        for update in split_repository_update(repository) {
1937            session.peer.send(session.connection_id, update)?;
1938        }
1939    }
1940
1941    for language_server in &project.language_servers {
1942        session.peer.send(
1943            session.connection_id,
1944            proto::UpdateLanguageServer {
1945                project_id: project_id.to_proto(),
1946                language_server_id: language_server.id,
1947                variant: Some(
1948                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1949                        proto::LspDiskBasedDiagnosticsUpdated {},
1950                    ),
1951                ),
1952            },
1953        )?;
1954    }
1955
1956    Ok(())
1957}
1958
1959/// Leave someone elses shared project.
1960async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1961    let sender_id = session.connection_id;
1962    let project_id = ProjectId::from_proto(request.project_id);
1963    let db = session.db().await;
1964
1965    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1966    tracing::info!(
1967        %project_id,
1968        "leave project"
1969    );
1970
1971    project_left(project, &session);
1972    if let Some(room) = room {
1973        room_updated(room, &session.peer);
1974    }
1975
1976    Ok(())
1977}
1978
1979/// Updates other participants with changes to the project
1980async fn update_project(
1981    request: proto::UpdateProject,
1982    response: Response<proto::UpdateProject>,
1983    session: Session,
1984) -> Result<()> {
1985    let project_id = ProjectId::from_proto(request.project_id);
1986    let (room, guest_connection_ids) = &*session
1987        .db()
1988        .await
1989        .update_project(project_id, session.connection_id, &request.worktrees)
1990        .await?;
1991    broadcast(
1992        Some(session.connection_id),
1993        guest_connection_ids.iter().copied(),
1994        |connection_id| {
1995            session
1996                .peer
1997                .forward_send(session.connection_id, connection_id, request.clone())
1998        },
1999    );
2000    if let Some(room) = room {
2001        room_updated(room, &session.peer);
2002    }
2003    response.send(proto::Ack {})?;
2004
2005    Ok(())
2006}
2007
2008/// Updates other participants with changes to the worktree
2009async fn update_worktree(
2010    request: proto::UpdateWorktree,
2011    response: Response<proto::UpdateWorktree>,
2012    session: Session,
2013) -> Result<()> {
2014    let guest_connection_ids = session
2015        .db()
2016        .await
2017        .update_worktree(&request, session.connection_id)
2018        .await?;
2019
2020    broadcast(
2021        Some(session.connection_id),
2022        guest_connection_ids.iter().copied(),
2023        |connection_id| {
2024            session
2025                .peer
2026                .forward_send(session.connection_id, connection_id, request.clone())
2027        },
2028    );
2029    response.send(proto::Ack {})?;
2030    Ok(())
2031}
2032
2033async fn update_repository(
2034    request: proto::UpdateRepository,
2035    response: Response<proto::UpdateRepository>,
2036    session: Session,
2037) -> Result<()> {
2038    let guest_connection_ids = session
2039        .db()
2040        .await
2041        .update_repository(&request, session.connection_id)
2042        .await?;
2043
2044    broadcast(
2045        Some(session.connection_id),
2046        guest_connection_ids.iter().copied(),
2047        |connection_id| {
2048            session
2049                .peer
2050                .forward_send(session.connection_id, connection_id, request.clone())
2051        },
2052    );
2053    response.send(proto::Ack {})?;
2054    Ok(())
2055}
2056
2057async fn remove_repository(
2058    request: proto::RemoveRepository,
2059    response: Response<proto::RemoveRepository>,
2060    session: Session,
2061) -> Result<()> {
2062    let guest_connection_ids = session
2063        .db()
2064        .await
2065        .remove_repository(&request, session.connection_id)
2066        .await?;
2067
2068    broadcast(
2069        Some(session.connection_id),
2070        guest_connection_ids.iter().copied(),
2071        |connection_id| {
2072            session
2073                .peer
2074                .forward_send(session.connection_id, connection_id, request.clone())
2075        },
2076    );
2077    response.send(proto::Ack {})?;
2078    Ok(())
2079}
2080
2081/// Updates other participants with changes to the diagnostics
2082async fn update_diagnostic_summary(
2083    message: proto::UpdateDiagnosticSummary,
2084    session: Session,
2085) -> Result<()> {
2086    let guest_connection_ids = session
2087        .db()
2088        .await
2089        .update_diagnostic_summary(&message, session.connection_id)
2090        .await?;
2091
2092    broadcast(
2093        Some(session.connection_id),
2094        guest_connection_ids.iter().copied(),
2095        |connection_id| {
2096            session
2097                .peer
2098                .forward_send(session.connection_id, connection_id, message.clone())
2099        },
2100    );
2101
2102    Ok(())
2103}
2104
2105/// Updates other participants with changes to the worktree settings
2106async fn update_worktree_settings(
2107    message: proto::UpdateWorktreeSettings,
2108    session: Session,
2109) -> Result<()> {
2110    let guest_connection_ids = session
2111        .db()
2112        .await
2113        .update_worktree_settings(&message, session.connection_id)
2114        .await?;
2115
2116    broadcast(
2117        Some(session.connection_id),
2118        guest_connection_ids.iter().copied(),
2119        |connection_id| {
2120            session
2121                .peer
2122                .forward_send(session.connection_id, connection_id, message.clone())
2123        },
2124    );
2125
2126    Ok(())
2127}
2128
2129/// Notify other participants that a language server has started.
2130async fn start_language_server(
2131    request: proto::StartLanguageServer,
2132    session: Session,
2133) -> Result<()> {
2134    let guest_connection_ids = session
2135        .db()
2136        .await
2137        .start_language_server(&request, session.connection_id)
2138        .await?;
2139
2140    broadcast(
2141        Some(session.connection_id),
2142        guest_connection_ids.iter().copied(),
2143        |connection_id| {
2144            session
2145                .peer
2146                .forward_send(session.connection_id, connection_id, request.clone())
2147        },
2148    );
2149    Ok(())
2150}
2151
2152/// Notify other participants that a language server has changed.
2153async fn update_language_server(
2154    request: proto::UpdateLanguageServer,
2155    session: Session,
2156) -> Result<()> {
2157    let project_id = ProjectId::from_proto(request.project_id);
2158    let project_connection_ids = session
2159        .db()
2160        .await
2161        .project_connection_ids(project_id, session.connection_id, true)
2162        .await?;
2163    broadcast(
2164        Some(session.connection_id),
2165        project_connection_ids.iter().copied(),
2166        |connection_id| {
2167            session
2168                .peer
2169                .forward_send(session.connection_id, connection_id, request.clone())
2170        },
2171    );
2172    Ok(())
2173}
2174
2175/// forward a project request to the host. These requests should be read only
2176/// as guests are allowed to send them.
2177async fn forward_read_only_project_request<T>(
2178    request: T,
2179    response: Response<T>,
2180    session: Session,
2181) -> Result<()>
2182where
2183    T: EntityMessage + RequestMessage,
2184{
2185    let project_id = ProjectId::from_proto(request.remote_entity_id());
2186    let host_connection_id = session
2187        .db()
2188        .await
2189        .host_for_read_only_project_request(project_id, session.connection_id)
2190        .await?;
2191    let payload = session
2192        .peer
2193        .forward_request(session.connection_id, host_connection_id, request)
2194        .await?;
2195    response.send(payload)?;
2196    Ok(())
2197}
2198
2199async fn forward_find_search_candidates_request(
2200    request: proto::FindSearchCandidates,
2201    response: Response<proto::FindSearchCandidates>,
2202    session: Session,
2203) -> Result<()> {
2204    let project_id = ProjectId::from_proto(request.remote_entity_id());
2205    let host_connection_id = session
2206        .db()
2207        .await
2208        .host_for_read_only_project_request(project_id, session.connection_id)
2209        .await?;
2210    let payload = session
2211        .peer
2212        .forward_request(session.connection_id, host_connection_id, request)
2213        .await?;
2214    response.send(payload)?;
2215    Ok(())
2216}
2217
2218/// forward a project request to the host. These requests are disallowed
2219/// for guests.
2220async fn forward_mutating_project_request<T>(
2221    request: T,
2222    response: Response<T>,
2223    session: Session,
2224) -> Result<()>
2225where
2226    T: EntityMessage + RequestMessage,
2227{
2228    let project_id = ProjectId::from_proto(request.remote_entity_id());
2229
2230    let host_connection_id = session
2231        .db()
2232        .await
2233        .host_for_mutating_project_request(project_id, session.connection_id)
2234        .await?;
2235    let payload = session
2236        .peer
2237        .forward_request(session.connection_id, host_connection_id, request)
2238        .await?;
2239    response.send(payload)?;
2240    Ok(())
2241}
2242
2243/// Notify other participants that a new buffer has been created
2244async fn create_buffer_for_peer(
2245    request: proto::CreateBufferForPeer,
2246    session: Session,
2247) -> Result<()> {
2248    session
2249        .db()
2250        .await
2251        .check_user_is_project_host(
2252            ProjectId::from_proto(request.project_id),
2253            session.connection_id,
2254        )
2255        .await?;
2256    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2257    session
2258        .peer
2259        .forward_send(session.connection_id, peer_id.into(), request)?;
2260    Ok(())
2261}
2262
2263/// Notify other participants that a buffer has been updated. This is
2264/// allowed for guests as long as the update is limited to selections.
2265async fn update_buffer(
2266    request: proto::UpdateBuffer,
2267    response: Response<proto::UpdateBuffer>,
2268    session: Session,
2269) -> Result<()> {
2270    let project_id = ProjectId::from_proto(request.project_id);
2271    let mut capability = Capability::ReadOnly;
2272
2273    for op in request.operations.iter() {
2274        match op.variant {
2275            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2276            Some(_) => capability = Capability::ReadWrite,
2277        }
2278    }
2279
2280    let host = {
2281        let guard = session
2282            .db()
2283            .await
2284            .connections_for_buffer_update(project_id, session.connection_id, capability)
2285            .await?;
2286
2287        let (host, guests) = &*guard;
2288
2289        broadcast(
2290            Some(session.connection_id),
2291            guests.clone(),
2292            |connection_id| {
2293                session
2294                    .peer
2295                    .forward_send(session.connection_id, connection_id, request.clone())
2296            },
2297        );
2298
2299        *host
2300    };
2301
2302    if host != session.connection_id {
2303        session
2304            .peer
2305            .forward_request(session.connection_id, host, request.clone())
2306            .await?;
2307    }
2308
2309    response.send(proto::Ack {})?;
2310    Ok(())
2311}
2312
2313async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2314    let project_id = ProjectId::from_proto(message.project_id);
2315
2316    let operation = message.operation.as_ref().context("invalid operation")?;
2317    let capability = match operation.variant.as_ref() {
2318        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2319            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2320                match buffer_op.variant {
2321                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2322                        Capability::ReadOnly
2323                    }
2324                    _ => Capability::ReadWrite,
2325                }
2326            } else {
2327                Capability::ReadWrite
2328            }
2329        }
2330        Some(_) => Capability::ReadWrite,
2331        None => Capability::ReadOnly,
2332    };
2333
2334    let guard = session
2335        .db()
2336        .await
2337        .connections_for_buffer_update(project_id, session.connection_id, capability)
2338        .await?;
2339
2340    let (host, guests) = &*guard;
2341
2342    broadcast(
2343        Some(session.connection_id),
2344        guests.iter().chain([host]).copied(),
2345        |connection_id| {
2346            session
2347                .peer
2348                .forward_send(session.connection_id, connection_id, message.clone())
2349        },
2350    );
2351
2352    Ok(())
2353}
2354
2355/// Notify other participants that a project has been updated.
2356async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2357    request: T,
2358    session: Session,
2359) -> Result<()> {
2360    let project_id = ProjectId::from_proto(request.remote_entity_id());
2361    let project_connection_ids = session
2362        .db()
2363        .await
2364        .project_connection_ids(project_id, session.connection_id, false)
2365        .await?;
2366
2367    broadcast(
2368        Some(session.connection_id),
2369        project_connection_ids.iter().copied(),
2370        |connection_id| {
2371            session
2372                .peer
2373                .forward_send(session.connection_id, connection_id, request.clone())
2374        },
2375    );
2376    Ok(())
2377}
2378
2379/// Start following another user in a call.
2380async fn follow(
2381    request: proto::Follow,
2382    response: Response<proto::Follow>,
2383    session: Session,
2384) -> Result<()> {
2385    let room_id = RoomId::from_proto(request.room_id);
2386    let project_id = request.project_id.map(ProjectId::from_proto);
2387    let leader_id = request
2388        .leader_id
2389        .ok_or_else(|| anyhow!("invalid leader id"))?
2390        .into();
2391    let follower_id = session.connection_id;
2392
2393    session
2394        .db()
2395        .await
2396        .check_room_participants(room_id, leader_id, session.connection_id)
2397        .await?;
2398
2399    let response_payload = session
2400        .peer
2401        .forward_request(session.connection_id, leader_id, request)
2402        .await?;
2403    response.send(response_payload)?;
2404
2405    if let Some(project_id) = project_id {
2406        let room = session
2407            .db()
2408            .await
2409            .follow(room_id, project_id, leader_id, follower_id)
2410            .await?;
2411        room_updated(&room, &session.peer);
2412    }
2413
2414    Ok(())
2415}
2416
2417/// Stop following another user in a call.
2418async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2419    let room_id = RoomId::from_proto(request.room_id);
2420    let project_id = request.project_id.map(ProjectId::from_proto);
2421    let leader_id = request
2422        .leader_id
2423        .ok_or_else(|| anyhow!("invalid leader id"))?
2424        .into();
2425    let follower_id = session.connection_id;
2426
2427    session
2428        .db()
2429        .await
2430        .check_room_participants(room_id, leader_id, session.connection_id)
2431        .await?;
2432
2433    session
2434        .peer
2435        .forward_send(session.connection_id, leader_id, request)?;
2436
2437    if let Some(project_id) = project_id {
2438        let room = session
2439            .db()
2440            .await
2441            .unfollow(room_id, project_id, leader_id, follower_id)
2442            .await?;
2443        room_updated(&room, &session.peer);
2444    }
2445
2446    Ok(())
2447}
2448
2449/// Notify everyone following you of your current location.
2450async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2451    let room_id = RoomId::from_proto(request.room_id);
2452    let database = session.db.lock().await;
2453
2454    let connection_ids = if let Some(project_id) = request.project_id {
2455        let project_id = ProjectId::from_proto(project_id);
2456        database
2457            .project_connection_ids(project_id, session.connection_id, true)
2458            .await?
2459    } else {
2460        database
2461            .room_connection_ids(room_id, session.connection_id)
2462            .await?
2463    };
2464
2465    // For now, don't send view update messages back to that view's current leader.
2466    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2467        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2468        _ => None,
2469    });
2470
2471    for connection_id in connection_ids.iter().cloned() {
2472        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2473            session
2474                .peer
2475                .forward_send(session.connection_id, connection_id, request.clone())?;
2476        }
2477    }
2478    Ok(())
2479}
2480
2481/// Get public data about users.
2482async fn get_users(
2483    request: proto::GetUsers,
2484    response: Response<proto::GetUsers>,
2485    session: Session,
2486) -> Result<()> {
2487    let user_ids = request
2488        .user_ids
2489        .into_iter()
2490        .map(UserId::from_proto)
2491        .collect();
2492    let users = session
2493        .db()
2494        .await
2495        .get_users_by_ids(user_ids)
2496        .await?
2497        .into_iter()
2498        .map(|user| proto::User {
2499            id: user.id.to_proto(),
2500            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2501            github_login: user.github_login,
2502            email: user.email_address,
2503            name: user.name,
2504        })
2505        .collect();
2506    response.send(proto::UsersResponse { users })?;
2507    Ok(())
2508}
2509
2510/// Search for users (to invite) buy Github login
2511async fn fuzzy_search_users(
2512    request: proto::FuzzySearchUsers,
2513    response: Response<proto::FuzzySearchUsers>,
2514    session: Session,
2515) -> Result<()> {
2516    let query = request.query;
2517    let users = match query.len() {
2518        0 => vec![],
2519        1 | 2 => session
2520            .db()
2521            .await
2522            .get_user_by_github_login(&query)
2523            .await?
2524            .into_iter()
2525            .collect(),
2526        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2527    };
2528    let users = users
2529        .into_iter()
2530        .filter(|user| user.id != session.user_id())
2531        .map(|user| proto::User {
2532            id: user.id.to_proto(),
2533            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2534            github_login: user.github_login,
2535            name: user.name,
2536            email: user.email_address,
2537        })
2538        .collect();
2539    response.send(proto::UsersResponse { users })?;
2540    Ok(())
2541}
2542
2543/// Send a contact request to another user.
2544async fn request_contact(
2545    request: proto::RequestContact,
2546    response: Response<proto::RequestContact>,
2547    session: Session,
2548) -> Result<()> {
2549    let requester_id = session.user_id();
2550    let responder_id = UserId::from_proto(request.responder_id);
2551    if requester_id == responder_id {
2552        return Err(anyhow!("cannot add yourself as a contact"))?;
2553    }
2554
2555    let notifications = session
2556        .db()
2557        .await
2558        .send_contact_request(requester_id, responder_id)
2559        .await?;
2560
2561    // Update outgoing contact requests of requester
2562    let mut update = proto::UpdateContacts::default();
2563    update.outgoing_requests.push(responder_id.to_proto());
2564    for connection_id in session
2565        .connection_pool()
2566        .await
2567        .user_connection_ids(requester_id)
2568    {
2569        session.peer.send(connection_id, update.clone())?;
2570    }
2571
2572    // Update incoming contact requests of responder
2573    let mut update = proto::UpdateContacts::default();
2574    update
2575        .incoming_requests
2576        .push(proto::IncomingContactRequest {
2577            requester_id: requester_id.to_proto(),
2578        });
2579    let connection_pool = session.connection_pool().await;
2580    for connection_id in connection_pool.user_connection_ids(responder_id) {
2581        session.peer.send(connection_id, update.clone())?;
2582    }
2583
2584    send_notifications(&connection_pool, &session.peer, notifications);
2585
2586    response.send(proto::Ack {})?;
2587    Ok(())
2588}
2589
2590/// Accept or decline a contact request
2591async fn respond_to_contact_request(
2592    request: proto::RespondToContactRequest,
2593    response: Response<proto::RespondToContactRequest>,
2594    session: Session,
2595) -> Result<()> {
2596    let responder_id = session.user_id();
2597    let requester_id = UserId::from_proto(request.requester_id);
2598    let db = session.db().await;
2599    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2600        db.dismiss_contact_notification(responder_id, requester_id)
2601            .await?;
2602    } else {
2603        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2604
2605        let notifications = db
2606            .respond_to_contact_request(responder_id, requester_id, accept)
2607            .await?;
2608        let requester_busy = db.is_user_busy(requester_id).await?;
2609        let responder_busy = db.is_user_busy(responder_id).await?;
2610
2611        let pool = session.connection_pool().await;
2612        // Update responder with new contact
2613        let mut update = proto::UpdateContacts::default();
2614        if accept {
2615            update
2616                .contacts
2617                .push(contact_for_user(requester_id, requester_busy, &pool));
2618        }
2619        update
2620            .remove_incoming_requests
2621            .push(requester_id.to_proto());
2622        for connection_id in pool.user_connection_ids(responder_id) {
2623            session.peer.send(connection_id, update.clone())?;
2624        }
2625
2626        // Update requester with new contact
2627        let mut update = proto::UpdateContacts::default();
2628        if accept {
2629            update
2630                .contacts
2631                .push(contact_for_user(responder_id, responder_busy, &pool));
2632        }
2633        update
2634            .remove_outgoing_requests
2635            .push(responder_id.to_proto());
2636
2637        for connection_id in pool.user_connection_ids(requester_id) {
2638            session.peer.send(connection_id, update.clone())?;
2639        }
2640
2641        send_notifications(&pool, &session.peer, notifications);
2642    }
2643
2644    response.send(proto::Ack {})?;
2645    Ok(())
2646}
2647
2648/// Remove a contact.
2649async fn remove_contact(
2650    request: proto::RemoveContact,
2651    response: Response<proto::RemoveContact>,
2652    session: Session,
2653) -> Result<()> {
2654    let requester_id = session.user_id();
2655    let responder_id = UserId::from_proto(request.user_id);
2656    let db = session.db().await;
2657    let (contact_accepted, deleted_notification_id) =
2658        db.remove_contact(requester_id, responder_id).await?;
2659
2660    let pool = session.connection_pool().await;
2661    // Update outgoing contact requests of requester
2662    let mut update = proto::UpdateContacts::default();
2663    if contact_accepted {
2664        update.remove_contacts.push(responder_id.to_proto());
2665    } else {
2666        update
2667            .remove_outgoing_requests
2668            .push(responder_id.to_proto());
2669    }
2670    for connection_id in pool.user_connection_ids(requester_id) {
2671        session.peer.send(connection_id, update.clone())?;
2672    }
2673
2674    // Update incoming contact requests of responder
2675    let mut update = proto::UpdateContacts::default();
2676    if contact_accepted {
2677        update.remove_contacts.push(requester_id.to_proto());
2678    } else {
2679        update
2680            .remove_incoming_requests
2681            .push(requester_id.to_proto());
2682    }
2683    for connection_id in pool.user_connection_ids(responder_id) {
2684        session.peer.send(connection_id, update.clone())?;
2685        if let Some(notification_id) = deleted_notification_id {
2686            session.peer.send(
2687                connection_id,
2688                proto::DeleteNotification {
2689                    notification_id: notification_id.to_proto(),
2690                },
2691            )?;
2692        }
2693    }
2694
2695    response.send(proto::Ack {})?;
2696    Ok(())
2697}
2698
2699fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2700    version.0.minor() < 139
2701}
2702
2703async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2704    let plan = session.current_plan(&session.db().await).await?;
2705
2706    session
2707        .peer
2708        .send(
2709            session.connection_id,
2710            proto::UpdateUserPlan { plan: plan.into() },
2711        )
2712        .trace_err();
2713
2714    Ok(())
2715}
2716
2717async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2718    subscribe_user_to_channels(session.user_id(), &session).await?;
2719    Ok(())
2720}
2721
2722async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2723    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2724    let mut pool = session.connection_pool().await;
2725    for membership in &channels_for_user.channel_memberships {
2726        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2727    }
2728    session.peer.send(
2729        session.connection_id,
2730        build_update_user_channels(&channels_for_user),
2731    )?;
2732    session.peer.send(
2733        session.connection_id,
2734        build_channels_update(channels_for_user),
2735    )?;
2736    Ok(())
2737}
2738
2739/// Creates a new channel.
2740async fn create_channel(
2741    request: proto::CreateChannel,
2742    response: Response<proto::CreateChannel>,
2743    session: Session,
2744) -> Result<()> {
2745    let db = session.db().await;
2746
2747    let parent_id = request.parent_id.map(ChannelId::from_proto);
2748    let (channel, membership) = db
2749        .create_channel(&request.name, parent_id, session.user_id())
2750        .await?;
2751
2752    let root_id = channel.root_id();
2753    let channel = Channel::from_model(channel);
2754
2755    response.send(proto::CreateChannelResponse {
2756        channel: Some(channel.to_proto()),
2757        parent_id: request.parent_id,
2758    })?;
2759
2760    let mut connection_pool = session.connection_pool().await;
2761    if let Some(membership) = membership {
2762        connection_pool.subscribe_to_channel(
2763            membership.user_id,
2764            membership.channel_id,
2765            membership.role,
2766        );
2767        let update = proto::UpdateUserChannels {
2768            channel_memberships: vec![proto::ChannelMembership {
2769                channel_id: membership.channel_id.to_proto(),
2770                role: membership.role.into(),
2771            }],
2772            ..Default::default()
2773        };
2774        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2775            session.peer.send(connection_id, update.clone())?;
2776        }
2777    }
2778
2779    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2780        if !role.can_see_channel(channel.visibility) {
2781            continue;
2782        }
2783
2784        let update = proto::UpdateChannels {
2785            channels: vec![channel.to_proto()],
2786            ..Default::default()
2787        };
2788        session.peer.send(connection_id, update.clone())?;
2789    }
2790
2791    Ok(())
2792}
2793
2794/// Delete a channel
2795async fn delete_channel(
2796    request: proto::DeleteChannel,
2797    response: Response<proto::DeleteChannel>,
2798    session: Session,
2799) -> Result<()> {
2800    let db = session.db().await;
2801
2802    let channel_id = request.channel_id;
2803    let (root_channel, removed_channels) = db
2804        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2805        .await?;
2806    response.send(proto::Ack {})?;
2807
2808    // Notify members of removed channels
2809    let mut update = proto::UpdateChannels::default();
2810    update
2811        .delete_channels
2812        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2813
2814    let connection_pool = session.connection_pool().await;
2815    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2816        session.peer.send(connection_id, update.clone())?;
2817    }
2818
2819    Ok(())
2820}
2821
2822/// Invite someone to join a channel.
2823async fn invite_channel_member(
2824    request: proto::InviteChannelMember,
2825    response: Response<proto::InviteChannelMember>,
2826    session: Session,
2827) -> Result<()> {
2828    let db = session.db().await;
2829    let channel_id = ChannelId::from_proto(request.channel_id);
2830    let invitee_id = UserId::from_proto(request.user_id);
2831    let InviteMemberResult {
2832        channel,
2833        notifications,
2834    } = db
2835        .invite_channel_member(
2836            channel_id,
2837            invitee_id,
2838            session.user_id(),
2839            request.role().into(),
2840        )
2841        .await?;
2842
2843    let update = proto::UpdateChannels {
2844        channel_invitations: vec![channel.to_proto()],
2845        ..Default::default()
2846    };
2847
2848    let connection_pool = session.connection_pool().await;
2849    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2850        session.peer.send(connection_id, update.clone())?;
2851    }
2852
2853    send_notifications(&connection_pool, &session.peer, notifications);
2854
2855    response.send(proto::Ack {})?;
2856    Ok(())
2857}
2858
2859/// remove someone from a channel
2860async fn remove_channel_member(
2861    request: proto::RemoveChannelMember,
2862    response: Response<proto::RemoveChannelMember>,
2863    session: Session,
2864) -> Result<()> {
2865    let db = session.db().await;
2866    let channel_id = ChannelId::from_proto(request.channel_id);
2867    let member_id = UserId::from_proto(request.user_id);
2868
2869    let RemoveChannelMemberResult {
2870        membership_update,
2871        notification_id,
2872    } = db
2873        .remove_channel_member(channel_id, member_id, session.user_id())
2874        .await?;
2875
2876    let mut connection_pool = session.connection_pool().await;
2877    notify_membership_updated(
2878        &mut connection_pool,
2879        membership_update,
2880        member_id,
2881        &session.peer,
2882    );
2883    for connection_id in connection_pool.user_connection_ids(member_id) {
2884        if let Some(notification_id) = notification_id {
2885            session
2886                .peer
2887                .send(
2888                    connection_id,
2889                    proto::DeleteNotification {
2890                        notification_id: notification_id.to_proto(),
2891                    },
2892                )
2893                .trace_err();
2894        }
2895    }
2896
2897    response.send(proto::Ack {})?;
2898    Ok(())
2899}
2900
2901/// Toggle the channel between public and private.
2902/// Care is taken to maintain the invariant that public channels only descend from public channels,
2903/// (though members-only channels can appear at any point in the hierarchy).
2904async fn set_channel_visibility(
2905    request: proto::SetChannelVisibility,
2906    response: Response<proto::SetChannelVisibility>,
2907    session: Session,
2908) -> Result<()> {
2909    let db = session.db().await;
2910    let channel_id = ChannelId::from_proto(request.channel_id);
2911    let visibility = request.visibility().into();
2912
2913    let channel_model = db
2914        .set_channel_visibility(channel_id, visibility, session.user_id())
2915        .await?;
2916    let root_id = channel_model.root_id();
2917    let channel = Channel::from_model(channel_model);
2918
2919    let mut connection_pool = session.connection_pool().await;
2920    for (user_id, role) in connection_pool
2921        .channel_user_ids(root_id)
2922        .collect::<Vec<_>>()
2923        .into_iter()
2924    {
2925        let update = if role.can_see_channel(channel.visibility) {
2926            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2927            proto::UpdateChannels {
2928                channels: vec![channel.to_proto()],
2929                ..Default::default()
2930            }
2931        } else {
2932            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2933            proto::UpdateChannels {
2934                delete_channels: vec![channel.id.to_proto()],
2935                ..Default::default()
2936            }
2937        };
2938
2939        for connection_id in connection_pool.user_connection_ids(user_id) {
2940            session.peer.send(connection_id, update.clone())?;
2941        }
2942    }
2943
2944    response.send(proto::Ack {})?;
2945    Ok(())
2946}
2947
2948/// Alter the role for a user in the channel.
2949async fn set_channel_member_role(
2950    request: proto::SetChannelMemberRole,
2951    response: Response<proto::SetChannelMemberRole>,
2952    session: Session,
2953) -> Result<()> {
2954    let db = session.db().await;
2955    let channel_id = ChannelId::from_proto(request.channel_id);
2956    let member_id = UserId::from_proto(request.user_id);
2957    let result = db
2958        .set_channel_member_role(
2959            channel_id,
2960            session.user_id(),
2961            member_id,
2962            request.role().into(),
2963        )
2964        .await?;
2965
2966    match result {
2967        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2968            let mut connection_pool = session.connection_pool().await;
2969            notify_membership_updated(
2970                &mut connection_pool,
2971                membership_update,
2972                member_id,
2973                &session.peer,
2974            )
2975        }
2976        db::SetMemberRoleResult::InviteUpdated(channel) => {
2977            let update = proto::UpdateChannels {
2978                channel_invitations: vec![channel.to_proto()],
2979                ..Default::default()
2980            };
2981
2982            for connection_id in session
2983                .connection_pool()
2984                .await
2985                .user_connection_ids(member_id)
2986            {
2987                session.peer.send(connection_id, update.clone())?;
2988            }
2989        }
2990    }
2991
2992    response.send(proto::Ack {})?;
2993    Ok(())
2994}
2995
2996/// Change the name of a channel
2997async fn rename_channel(
2998    request: proto::RenameChannel,
2999    response: Response<proto::RenameChannel>,
3000    session: Session,
3001) -> Result<()> {
3002    let db = session.db().await;
3003    let channel_id = ChannelId::from_proto(request.channel_id);
3004    let channel_model = db
3005        .rename_channel(channel_id, session.user_id(), &request.name)
3006        .await?;
3007    let root_id = channel_model.root_id();
3008    let channel = Channel::from_model(channel_model);
3009
3010    response.send(proto::RenameChannelResponse {
3011        channel: Some(channel.to_proto()),
3012    })?;
3013
3014    let connection_pool = session.connection_pool().await;
3015    let update = proto::UpdateChannels {
3016        channels: vec![channel.to_proto()],
3017        ..Default::default()
3018    };
3019    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3020        if role.can_see_channel(channel.visibility) {
3021            session.peer.send(connection_id, update.clone())?;
3022        }
3023    }
3024
3025    Ok(())
3026}
3027
3028/// Move a channel to a new parent.
3029async fn move_channel(
3030    request: proto::MoveChannel,
3031    response: Response<proto::MoveChannel>,
3032    session: Session,
3033) -> Result<()> {
3034    let channel_id = ChannelId::from_proto(request.channel_id);
3035    let to = ChannelId::from_proto(request.to);
3036
3037    let (root_id, channels) = session
3038        .db()
3039        .await
3040        .move_channel(channel_id, to, session.user_id())
3041        .await?;
3042
3043    let connection_pool = session.connection_pool().await;
3044    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3045        let channels = channels
3046            .iter()
3047            .filter_map(|channel| {
3048                if role.can_see_channel(channel.visibility) {
3049                    Some(channel.to_proto())
3050                } else {
3051                    None
3052                }
3053            })
3054            .collect::<Vec<_>>();
3055        if channels.is_empty() {
3056            continue;
3057        }
3058
3059        let update = proto::UpdateChannels {
3060            channels,
3061            ..Default::default()
3062        };
3063
3064        session.peer.send(connection_id, update.clone())?;
3065    }
3066
3067    response.send(Ack {})?;
3068    Ok(())
3069}
3070
3071/// Get the list of channel members
3072async fn get_channel_members(
3073    request: proto::GetChannelMembers,
3074    response: Response<proto::GetChannelMembers>,
3075    session: Session,
3076) -> Result<()> {
3077    let db = session.db().await;
3078    let channel_id = ChannelId::from_proto(request.channel_id);
3079    let limit = if request.limit == 0 {
3080        u16::MAX as u64
3081    } else {
3082        request.limit
3083    };
3084    let (members, users) = db
3085        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3086        .await?;
3087    response.send(proto::GetChannelMembersResponse { members, users })?;
3088    Ok(())
3089}
3090
3091/// Accept or decline a channel invitation.
3092async fn respond_to_channel_invite(
3093    request: proto::RespondToChannelInvite,
3094    response: Response<proto::RespondToChannelInvite>,
3095    session: Session,
3096) -> Result<()> {
3097    let db = session.db().await;
3098    let channel_id = ChannelId::from_proto(request.channel_id);
3099    let RespondToChannelInvite {
3100        membership_update,
3101        notifications,
3102    } = db
3103        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3104        .await?;
3105
3106    let mut connection_pool = session.connection_pool().await;
3107    if let Some(membership_update) = membership_update {
3108        notify_membership_updated(
3109            &mut connection_pool,
3110            membership_update,
3111            session.user_id(),
3112            &session.peer,
3113        );
3114    } else {
3115        let update = proto::UpdateChannels {
3116            remove_channel_invitations: vec![channel_id.to_proto()],
3117            ..Default::default()
3118        };
3119
3120        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3121            session.peer.send(connection_id, update.clone())?;
3122        }
3123    };
3124
3125    send_notifications(&connection_pool, &session.peer, notifications);
3126
3127    response.send(proto::Ack {})?;
3128
3129    Ok(())
3130}
3131
3132/// Join the channels' room
3133async fn join_channel(
3134    request: proto::JoinChannel,
3135    response: Response<proto::JoinChannel>,
3136    session: Session,
3137) -> Result<()> {
3138    let channel_id = ChannelId::from_proto(request.channel_id);
3139    join_channel_internal(channel_id, Box::new(response), session).await
3140}
3141
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for the response handles of both `proto::JoinChannel` and
/// `proto::JoinRoom`, so `join_channel_internal` can answer either kind of
/// request.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): the explicit `Response::<..>::send` path presumably selects
        // `Response`'s own `send` rather than this trait method — confirm.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as the `JoinChannel` impl, for `JoinRoom` responses.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3155
/// Joins the session's user to the room of `channel_id` and replies through
/// `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for the user, leave that
///    room first.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Send the `JoinRoomResponse`.
/// 5. Propagate any membership update, then broadcast the room, the channel and
///    the user's contacts.
///
/// Fails with "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The database guard is released while leaving the room and re-acquired
            // afterwards (presumably `leave_room_for_session` locks the database
            // itself — TODO confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails: the `?` on the `Option` from `trace_err` exits the closure early.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other
                    // role gets a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops the database and connection-pool guards.
        joined_room
    };

    // The connection pool is locked again here, only for the duration of this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3251
3252/// Start editing the channel notes
3253async fn join_channel_buffer(
3254    request: proto::JoinChannelBuffer,
3255    response: Response<proto::JoinChannelBuffer>,
3256    session: Session,
3257) -> Result<()> {
3258    let db = session.db().await;
3259    let channel_id = ChannelId::from_proto(request.channel_id);
3260
3261    let open_response = db
3262        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3263        .await?;
3264
3265    let collaborators = open_response.collaborators.clone();
3266    response.send(open_response)?;
3267
3268    let update = UpdateChannelBufferCollaborators {
3269        channel_id: channel_id.to_proto(),
3270        collaborators: collaborators.clone(),
3271    };
3272    channel_buffer_updated(
3273        session.connection_id,
3274        collaborators
3275            .iter()
3276            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3277        &update,
3278        &session.peer,
3279    );
3280
3281    Ok(())
3282}
3283
3284/// Edit the channel notes
3285async fn update_channel_buffer(
3286    request: proto::UpdateChannelBuffer,
3287    session: Session,
3288) -> Result<()> {
3289    let db = session.db().await;
3290    let channel_id = ChannelId::from_proto(request.channel_id);
3291
3292    let (collaborators, epoch, version) = db
3293        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3294        .await?;
3295
3296    channel_buffer_updated(
3297        session.connection_id,
3298        collaborators.clone(),
3299        &proto::UpdateChannelBuffer {
3300            channel_id: channel_id.to_proto(),
3301            operations: request.operations,
3302        },
3303        &session.peer,
3304    );
3305
3306    let pool = &*session.connection_pool().await;
3307
3308    let non_collaborators =
3309        pool.channel_connection_ids(channel_id)
3310            .filter_map(|(connection_id, _)| {
3311                if collaborators.contains(&connection_id) {
3312                    None
3313                } else {
3314                    Some(connection_id)
3315                }
3316            });
3317
3318    broadcast(None, non_collaborators, |peer_id| {
3319        session.peer.send(
3320            peer_id,
3321            proto::UpdateChannels {
3322                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3323                    channel_id: channel_id.to_proto(),
3324                    epoch: epoch as u64,
3325                    version: version.clone(),
3326                }],
3327                ..Default::default()
3328            },
3329        )
3330    });
3331
3332    Ok(())
3333}
3334
3335/// Rejoin the channel notes after a connection blip
3336async fn rejoin_channel_buffers(
3337    request: proto::RejoinChannelBuffers,
3338    response: Response<proto::RejoinChannelBuffers>,
3339    session: Session,
3340) -> Result<()> {
3341    let db = session.db().await;
3342    let buffers = db
3343        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3344        .await?;
3345
3346    for rejoined_buffer in &buffers {
3347        let collaborators_to_notify = rejoined_buffer
3348            .buffer
3349            .collaborators
3350            .iter()
3351            .filter_map(|c| Some(c.peer_id?.into()));
3352        channel_buffer_updated(
3353            session.connection_id,
3354            collaborators_to_notify,
3355            &proto::UpdateChannelBufferCollaborators {
3356                channel_id: rejoined_buffer.buffer.channel_id,
3357                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3358            },
3359            &session.peer,
3360        );
3361    }
3362
3363    response.send(proto::RejoinChannelBuffersResponse {
3364        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3365    })?;
3366
3367    Ok(())
3368}
3369
3370/// Stop editing the channel notes
3371async fn leave_channel_buffer(
3372    request: proto::LeaveChannelBuffer,
3373    response: Response<proto::LeaveChannelBuffer>,
3374    session: Session,
3375) -> Result<()> {
3376    let db = session.db().await;
3377    let channel_id = ChannelId::from_proto(request.channel_id);
3378
3379    let left_buffer = db
3380        .leave_channel_buffer(channel_id, session.connection_id)
3381        .await?;
3382
3383    response.send(Ack {})?;
3384
3385    channel_buffer_updated(
3386        session.connection_id,
3387        left_buffer.connections,
3388        &proto::UpdateChannelBufferCollaborators {
3389            channel_id: channel_id.to_proto(),
3390            collaborators: left_buffer.collaborators,
3391        },
3392        &session.peer,
3393    );
3394
3395    Ok(())
3396}
3397
3398fn channel_buffer_updated<T: EnvelopedMessage>(
3399    sender_id: ConnectionId,
3400    collaborators: impl IntoIterator<Item = ConnectionId>,
3401    message: &T,
3402    peer: &Peer,
3403) {
3404    broadcast(Some(sender_id), collaborators, |peer_id| {
3405        peer.send(peer_id, message.clone())
3406    });
3407}
3408
3409fn send_notifications(
3410    connection_pool: &ConnectionPool,
3411    peer: &Peer,
3412    notifications: db::NotificationBatch,
3413) {
3414    for (user_id, notification) in notifications {
3415        for connection_id in connection_pool.user_connection_ids(user_id) {
3416            if let Err(error) = peer.send(
3417                connection_id,
3418                proto::AddNotification {
3419                    notification: Some(notification.clone()),
3420                },
3421            ) {
3422                tracing::error!(
3423                    "failed to send notification to {:?} {}",
3424                    connection_id,
3425                    error
3426                );
3427            }
3428        }
3429    }
3430}
3431
3432/// Send a message to the channel
3433async fn send_channel_message(
3434    request: proto::SendChannelMessage,
3435    response: Response<proto::SendChannelMessage>,
3436    session: Session,
3437) -> Result<()> {
3438    // Validate the message body.
3439    let body = request.body.trim().to_string();
3440    if body.len() > MAX_MESSAGE_LEN {
3441        return Err(anyhow!("message is too long"))?;
3442    }
3443    if body.is_empty() {
3444        return Err(anyhow!("message can't be blank"))?;
3445    }
3446
3447    // TODO: adjust mentions if body is trimmed
3448
3449    let timestamp = OffsetDateTime::now_utc();
3450    let nonce = request
3451        .nonce
3452        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3453
3454    let channel_id = ChannelId::from_proto(request.channel_id);
3455    let CreatedChannelMessage {
3456        message_id,
3457        participant_connection_ids,
3458        notifications,
3459    } = session
3460        .db()
3461        .await
3462        .create_channel_message(
3463            channel_id,
3464            session.user_id(),
3465            &body,
3466            &request.mentions,
3467            timestamp,
3468            nonce.clone().into(),
3469            request.reply_to_message_id.map(MessageId::from_proto),
3470        )
3471        .await?;
3472
3473    let message = proto::ChannelMessage {
3474        sender_id: session.user_id().to_proto(),
3475        id: message_id.to_proto(),
3476        body,
3477        mentions: request.mentions,
3478        timestamp: timestamp.unix_timestamp() as u64,
3479        nonce: Some(nonce),
3480        reply_to_message_id: request.reply_to_message_id,
3481        edited_at: None,
3482    };
3483    broadcast(
3484        Some(session.connection_id),
3485        participant_connection_ids.clone(),
3486        |connection| {
3487            session.peer.send(
3488                connection,
3489                proto::ChannelMessageSent {
3490                    channel_id: channel_id.to_proto(),
3491                    message: Some(message.clone()),
3492                },
3493            )
3494        },
3495    );
3496    response.send(proto::SendChannelMessageResponse {
3497        message: Some(message),
3498    })?;
3499
3500    let pool = &*session.connection_pool().await;
3501    let non_participants =
3502        pool.channel_connection_ids(channel_id)
3503            .filter_map(|(connection_id, _)| {
3504                if participant_connection_ids.contains(&connection_id) {
3505                    None
3506                } else {
3507                    Some(connection_id)
3508                }
3509            });
3510    broadcast(None, non_participants, |peer_id| {
3511        session.peer.send(
3512            peer_id,
3513            proto::UpdateChannels {
3514                latest_channel_message_ids: vec![proto::ChannelMessageId {
3515                    channel_id: channel_id.to_proto(),
3516                    message_id: message_id.to_proto(),
3517                }],
3518                ..Default::default()
3519            },
3520        )
3521    });
3522    send_notifications(pool, &session.peer, notifications);
3523
3524    Ok(())
3525}
3526
3527/// Delete a channel message
3528async fn remove_channel_message(
3529    request: proto::RemoveChannelMessage,
3530    response: Response<proto::RemoveChannelMessage>,
3531    session: Session,
3532) -> Result<()> {
3533    let channel_id = ChannelId::from_proto(request.channel_id);
3534    let message_id = MessageId::from_proto(request.message_id);
3535    let (connection_ids, existing_notification_ids) = session
3536        .db()
3537        .await
3538        .remove_channel_message(channel_id, message_id, session.user_id())
3539        .await?;
3540
3541    broadcast(
3542        Some(session.connection_id),
3543        connection_ids,
3544        move |connection| {
3545            session.peer.send(connection, request.clone())?;
3546
3547            for notification_id in &existing_notification_ids {
3548                session.peer.send(
3549                    connection,
3550                    proto::DeleteNotification {
3551                        notification_id: (*notification_id).to_proto(),
3552                    },
3553                )?;
3554            }
3555
3556            Ok(())
3557        },
3558    );
3559    response.send(proto::Ack {})?;
3560    Ok(())
3561}
3562
3563async fn update_channel_message(
3564    request: proto::UpdateChannelMessage,
3565    response: Response<proto::UpdateChannelMessage>,
3566    session: Session,
3567) -> Result<()> {
3568    let channel_id = ChannelId::from_proto(request.channel_id);
3569    let message_id = MessageId::from_proto(request.message_id);
3570    let updated_at = OffsetDateTime::now_utc();
3571    let UpdatedChannelMessage {
3572        message_id,
3573        participant_connection_ids,
3574        notifications,
3575        reply_to_message_id,
3576        timestamp,
3577        deleted_mention_notification_ids,
3578        updated_mention_notifications,
3579    } = session
3580        .db()
3581        .await
3582        .update_channel_message(
3583            channel_id,
3584            message_id,
3585            session.user_id(),
3586            request.body.as_str(),
3587            &request.mentions,
3588            updated_at,
3589        )
3590        .await?;
3591
3592    let nonce = request
3593        .nonce
3594        .clone()
3595        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3596
3597    let message = proto::ChannelMessage {
3598        sender_id: session.user_id().to_proto(),
3599        id: message_id.to_proto(),
3600        body: request.body.clone(),
3601        mentions: request.mentions.clone(),
3602        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3603        nonce: Some(nonce),
3604        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3605        edited_at: Some(updated_at.unix_timestamp() as u64),
3606    };
3607
3608    response.send(proto::Ack {})?;
3609
3610    let pool = &*session.connection_pool().await;
3611    broadcast(
3612        Some(session.connection_id),
3613        participant_connection_ids,
3614        |connection| {
3615            session.peer.send(
3616                connection,
3617                proto::ChannelMessageUpdate {
3618                    channel_id: channel_id.to_proto(),
3619                    message: Some(message.clone()),
3620                },
3621            )?;
3622
3623            for notification_id in &deleted_mention_notification_ids {
3624                session.peer.send(
3625                    connection,
3626                    proto::DeleteNotification {
3627                        notification_id: (*notification_id).to_proto(),
3628                    },
3629                )?;
3630            }
3631
3632            for notification in &updated_mention_notifications {
3633                session.peer.send(
3634                    connection,
3635                    proto::UpdateNotification {
3636                        notification: Some(notification.clone()),
3637                    },
3638                )?;
3639            }
3640
3641            Ok(())
3642        },
3643    );
3644
3645    send_notifications(pool, &session.peer, notifications);
3646
3647    Ok(())
3648}
3649
3650/// Mark a channel message as read
3651async fn acknowledge_channel_message(
3652    request: proto::AckChannelMessage,
3653    session: Session,
3654) -> Result<()> {
3655    let channel_id = ChannelId::from_proto(request.channel_id);
3656    let message_id = MessageId::from_proto(request.message_id);
3657    let notifications = session
3658        .db()
3659        .await
3660        .observe_channel_message(channel_id, session.user_id(), message_id)
3661        .await?;
3662    send_notifications(
3663        &*session.connection_pool().await,
3664        &session.peer,
3665        notifications,
3666    );
3667    Ok(())
3668}
3669
3670/// Mark a buffer version as synced
3671async fn acknowledge_buffer_version(
3672    request: proto::AckBufferOperation,
3673    session: Session,
3674) -> Result<()> {
3675    let buffer_id = BufferId::from_proto(request.buffer_id);
3676    session
3677        .db()
3678        .await
3679        .observe_buffer_version(
3680            buffer_id,
3681            session.user_id(),
3682            request.epoch as i32,
3683            &request.version,
3684        )
3685        .await?;
3686    Ok(())
3687}
3688
3689async fn count_language_model_tokens(
3690    request: proto::CountLanguageModelTokens,
3691    response: Response<proto::CountLanguageModelTokens>,
3692    session: Session,
3693    config: &Config,
3694) -> Result<()> {
3695    authorize_access_to_legacy_llm_endpoints(&session).await?;
3696
3697    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3698        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3699        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3700    };
3701
3702    session
3703        .app_state
3704        .rate_limiter
3705        .check(&*rate_limit, session.user_id())
3706        .await?;
3707
3708    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3709        Some(proto::LanguageModelProvider::Google) => {
3710            let api_key = config
3711                .google_ai_api_key
3712                .as_ref()
3713                .context("no Google AI API key configured on the server")?;
3714            google_ai::count_tokens(
3715                session.http_client.as_ref(),
3716                google_ai::API_URL,
3717                api_key,
3718                serde_json::from_str(&request.request)?,
3719            )
3720            .await?
3721        }
3722        _ => return Err(anyhow!("unsupported provider"))?,
3723    };
3724
3725    response.send(proto::CountLanguageModelTokensResponse {
3726        token_count: result.total_tokens as u32,
3727    })?;
3728
3729    Ok(())
3730}
3731
3732struct ZedProCountLanguageModelTokensRateLimit;
3733
3734impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3735    fn capacity(&self) -> usize {
3736        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3737            .ok()
3738            .and_then(|v| v.parse().ok())
3739            .unwrap_or(600) // Picked arbitrarily
3740    }
3741
3742    fn refill_duration(&self) -> chrono::Duration {
3743        chrono::Duration::hours(1)
3744    }
3745
3746    fn db_name(&self) -> &'static str {
3747        "zed-pro:count-language-model-tokens"
3748    }
3749}
3750
3751struct FreeCountLanguageModelTokensRateLimit;
3752
3753impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3754    fn capacity(&self) -> usize {
3755        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3756            .ok()
3757            .and_then(|v| v.parse().ok())
3758            .unwrap_or(600 / 10) // Picked arbitrarily
3759    }
3760
3761    fn refill_duration(&self) -> chrono::Duration {
3762        chrono::Duration::hours(1)
3763    }
3764
3765    fn db_name(&self) -> &'static str {
3766        "free:count-language-model-tokens"
3767    }
3768}
3769
3770struct ZedProComputeEmbeddingsRateLimit;
3771
3772impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3773    fn capacity(&self) -> usize {
3774        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3775            .ok()
3776            .and_then(|v| v.parse().ok())
3777            .unwrap_or(5000) // Picked arbitrarily
3778    }
3779
3780    fn refill_duration(&self) -> chrono::Duration {
3781        chrono::Duration::hours(1)
3782    }
3783
3784    fn db_name(&self) -> &'static str {
3785        "zed-pro:compute-embeddings"
3786    }
3787}
3788
3789struct FreeComputeEmbeddingsRateLimit;
3790
3791impl RateLimit for FreeComputeEmbeddingsRateLimit {
3792    fn capacity(&self) -> usize {
3793        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3794            .ok()
3795            .and_then(|v| v.parse().ok())
3796            .unwrap_or(5000 / 10) // Picked arbitrarily
3797    }
3798
3799    fn refill_duration(&self) -> chrono::Duration {
3800        chrono::Duration::hours(1)
3801    }
3802
3803    fn db_name(&self) -> &'static str {
3804        "free:compute-embeddings"
3805    }
3806}
3807
3808async fn compute_embeddings(
3809    request: proto::ComputeEmbeddings,
3810    response: Response<proto::ComputeEmbeddings>,
3811    session: Session,
3812    api_key: Option<Arc<str>>,
3813) -> Result<()> {
3814    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3815    authorize_access_to_legacy_llm_endpoints(&session).await?;
3816
3817    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3818        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3819        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3820    };
3821
3822    session
3823        .app_state
3824        .rate_limiter
3825        .check(&*rate_limit, session.user_id())
3826        .await?;
3827
3828    let embeddings = match request.model.as_str() {
3829        "openai/text-embedding-3-small" => {
3830            open_ai::embed(
3831                session.http_client.as_ref(),
3832                OPEN_AI_API_URL,
3833                &api_key,
3834                OpenAiEmbeddingModel::TextEmbedding3Small,
3835                request.texts.iter().map(|text| text.as_str()),
3836            )
3837            .await?
3838        }
3839        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3840    };
3841
3842    let embeddings = request
3843        .texts
3844        .iter()
3845        .map(|text| {
3846            let mut hasher = sha2::Sha256::new();
3847            hasher.update(text.as_bytes());
3848            let result = hasher.finalize();
3849            result.to_vec()
3850        })
3851        .zip(
3852            embeddings
3853                .data
3854                .into_iter()
3855                .map(|embedding| embedding.embedding),
3856        )
3857        .collect::<HashMap<_, _>>();
3858
3859    let db = session.db().await;
3860    db.save_embeddings(&request.model, &embeddings)
3861        .await
3862        .context("failed to save embeddings")
3863        .trace_err();
3864
3865    response.send(proto::ComputeEmbeddingsResponse {
3866        embeddings: embeddings
3867            .into_iter()
3868            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3869            .collect(),
3870    })?;
3871    Ok(())
3872}
3873
3874async fn get_cached_embeddings(
3875    request: proto::GetCachedEmbeddings,
3876    response: Response<proto::GetCachedEmbeddings>,
3877    session: Session,
3878) -> Result<()> {
3879    authorize_access_to_legacy_llm_endpoints(&session).await?;
3880
3881    let db = session.db().await;
3882    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3883
3884    response.send(proto::GetCachedEmbeddingsResponse {
3885        embeddings: embeddings
3886            .into_iter()
3887            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3888            .collect(),
3889    })?;
3890    Ok(())
3891}
3892
3893/// This is leftover from before the LLM service.
3894///
3895/// The endpoints protected by this check will be moved there eventually.
3896async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3897    if session.is_staff() {
3898        Ok(())
3899    } else {
3900        Err(anyhow!("permission denied"))?
3901    }
3902}
3903
3904/// Get a Supermaven API key for the user
3905async fn get_supermaven_api_key(
3906    _request: proto::GetSupermavenApiKey,
3907    response: Response<proto::GetSupermavenApiKey>,
3908    session: Session,
3909) -> Result<()> {
3910    let user_id: String = session.user_id().to_string();
3911    if !session.is_staff() {
3912        return Err(anyhow!("supermaven not enabled for this account"))?;
3913    }
3914
3915    let email = session
3916        .email()
3917        .ok_or_else(|| anyhow!("user must have an email"))?;
3918
3919    let supermaven_admin_api = session
3920        .supermaven_client
3921        .as_ref()
3922        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3923
3924    let result = supermaven_admin_api
3925        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3926        .await?;
3927
3928    response.send(proto::GetSupermavenApiKeyResponse {
3929        api_key: result.api_key,
3930    })?;
3931
3932    Ok(())
3933}
3934
3935/// Start receiving chat updates for a channel
3936async fn join_channel_chat(
3937    request: proto::JoinChannelChat,
3938    response: Response<proto::JoinChannelChat>,
3939    session: Session,
3940) -> Result<()> {
3941    let channel_id = ChannelId::from_proto(request.channel_id);
3942
3943    let db = session.db().await;
3944    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3945        .await?;
3946    let messages = db
3947        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3948        .await?;
3949    response.send(proto::JoinChannelChatResponse {
3950        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3951        messages,
3952    })?;
3953    Ok(())
3954}
3955
3956/// Stop receiving chat updates for a channel
3957async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3958    let channel_id = ChannelId::from_proto(request.channel_id);
3959    session
3960        .db()
3961        .await
3962        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3963        .await?;
3964    Ok(())
3965}
3966
3967/// Retrieve the chat history for a channel
3968async fn get_channel_messages(
3969    request: proto::GetChannelMessages,
3970    response: Response<proto::GetChannelMessages>,
3971    session: Session,
3972) -> Result<()> {
3973    let channel_id = ChannelId::from_proto(request.channel_id);
3974    let messages = session
3975        .db()
3976        .await
3977        .get_channel_messages(
3978            channel_id,
3979            session.user_id(),
3980            MESSAGE_COUNT_PER_PAGE,
3981            Some(MessageId::from_proto(request.before_message_id)),
3982        )
3983        .await?;
3984    response.send(proto::GetChannelMessagesResponse {
3985        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3986        messages,
3987    })?;
3988    Ok(())
3989}
3990
3991/// Retrieve specific chat messages
3992async fn get_channel_messages_by_id(
3993    request: proto::GetChannelMessagesById,
3994    response: Response<proto::GetChannelMessagesById>,
3995    session: Session,
3996) -> Result<()> {
3997    let message_ids = request
3998        .message_ids
3999        .iter()
4000        .map(|id| MessageId::from_proto(*id))
4001        .collect::<Vec<_>>();
4002    let messages = session
4003        .db()
4004        .await
4005        .get_channel_messages_by_id(session.user_id(), &message_ids)
4006        .await?;
4007    response.send(proto::GetChannelMessagesResponse {
4008        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4009        messages,
4010    })?;
4011    Ok(())
4012}
4013
4014/// Retrieve the current users notifications
4015async fn get_notifications(
4016    request: proto::GetNotifications,
4017    response: Response<proto::GetNotifications>,
4018    session: Session,
4019) -> Result<()> {
4020    let notifications = session
4021        .db()
4022        .await
4023        .get_notifications(
4024            session.user_id(),
4025            NOTIFICATION_COUNT_PER_PAGE,
4026            request.before_id.map(db::NotificationId::from_proto),
4027        )
4028        .await?;
4029    response.send(proto::GetNotificationsResponse {
4030        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4031        notifications,
4032    })?;
4033    Ok(())
4034}
4035
4036/// Mark notifications as read
4037async fn mark_notification_as_read(
4038    request: proto::MarkNotificationRead,
4039    response: Response<proto::MarkNotificationRead>,
4040    session: Session,
4041) -> Result<()> {
4042    let database = &session.db().await;
4043    let notifications = database
4044        .mark_notification_as_read_by_id(
4045            session.user_id(),
4046            NotificationId::from_proto(request.notification_id),
4047        )
4048        .await?;
4049    send_notifications(
4050        &*session.connection_pool().await,
4051        &session.peer,
4052        notifications,
4053    );
4054    response.send(proto::Ack {})?;
4055    Ok(())
4056}
4057
4058/// Get the current users information
4059async fn get_private_user_info(
4060    _request: proto::GetPrivateUserInfo,
4061    response: Response<proto::GetPrivateUserInfo>,
4062    session: Session,
4063) -> Result<()> {
4064    let db = session.db().await;
4065
4066    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4067    let user = db
4068        .get_user_by_id(session.user_id())
4069        .await?
4070        .ok_or_else(|| anyhow!("user not found"))?;
4071    let flags = db.get_user_flags(session.user_id()).await?;
4072
4073    response.send(proto::GetPrivateUserInfoResponse {
4074        metrics_id,
4075        staff: user.admin,
4076        flags,
4077        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4078    })?;
4079    Ok(())
4080}
4081
4082/// Accept the terms of service (tos) on behalf of the current user
4083async fn accept_terms_of_service(
4084    _request: proto::AcceptTermsOfService,
4085    response: Response<proto::AcceptTermsOfService>,
4086    session: Session,
4087) -> Result<()> {
4088    let db = session.db().await;
4089
4090    let accepted_tos_at = Utc::now();
4091    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4092        .await?;
4093
4094    response.send(proto::AcceptTermsOfServiceResponse {
4095        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4096    })?;
4097    Ok(())
4098}
4099
/// The minimum age (30 days) an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4102
4103async fn get_llm_api_token(
4104    _request: proto::GetLlmToken,
4105    response: Response<proto::GetLlmToken>,
4106    session: Session,
4107) -> Result<()> {
4108    let db = session.db().await;
4109
4110    let flags = db.get_user_flags(session.user_id()).await?;
4111    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4112
4113    if !session.is_staff() && !has_language_models_feature_flag {
4114        Err(anyhow!("permission denied"))?
4115    }
4116
4117    let user_id = session.user_id();
4118    let user = db
4119        .get_user_by_id(user_id)
4120        .await?
4121        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4122
4123    if user.accepted_tos_at.is_none() {
4124        Err(anyhow!("terms of service not accepted"))?
4125    }
4126
4127    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4128    let billing_preferences = db.get_billing_preferences(user.id).await?;
4129
4130    let token = LlmTokenClaims::create(
4131        &user,
4132        session.is_staff(),
4133        billing_preferences,
4134        &flags,
4135        has_llm_subscription,
4136        session.current_plan(&db).await?,
4137        session.system_id.clone(),
4138        &session.app_state.config,
4139    )?;
4140    response.send(proto::GetLlmTokenResponse { token })?;
4141    Ok(())
4142}
4143
4144fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4145    let message = match message {
4146        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4147        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4148        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4149        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4150        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4151            code: frame.code.into(),
4152            reason: frame.reason,
4153        })),
4154        // We should never receive a frame while reading the message, according
4155        // to the `tungstenite` maintainers:
4156        //
4157        // > It cannot occur when you read messages from the WebSocket, but it
4158        // > can be used when you want to send the raw frames (e.g. you want to
4159        // > send the frames to the WebSocket without composing the full message first).
4160        // >
4161        // > — https://github.com/snapview/tungstenite-rs/issues/268
4162        TungsteniteMessage::Frame(_) => {
4163            bail!("received an unexpected frame while reading the message")
4164        }
4165    };
4166
4167    Ok(message)
4168}
4169
4170fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4171    match message {
4172        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4173        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4174        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4175        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4176        AxumMessage::Close(frame) => {
4177            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4178                code: frame.code.into(),
4179                reason: frame.reason,
4180            }))
4181        }
4182    }
4183}
4184
4185fn notify_membership_updated(
4186    connection_pool: &mut ConnectionPool,
4187    result: MembershipUpdated,
4188    user_id: UserId,
4189    peer: &Peer,
4190) {
4191    for membership in &result.new_channels.channel_memberships {
4192        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4193    }
4194    for channel_id in &result.removed_channels {
4195        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4196    }
4197
4198    let user_channels_update = proto::UpdateUserChannels {
4199        channel_memberships: result
4200            .new_channels
4201            .channel_memberships
4202            .iter()
4203            .map(|cm| proto::ChannelMembership {
4204                channel_id: cm.channel_id.to_proto(),
4205                role: cm.role.into(),
4206            })
4207            .collect(),
4208        ..Default::default()
4209    };
4210
4211    let mut update = build_channels_update(result.new_channels);
4212    update.delete_channels = result
4213        .removed_channels
4214        .into_iter()
4215        .map(|id| id.to_proto())
4216        .collect();
4217    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4218
4219    for connection_id in connection_pool.user_connection_ids(user_id) {
4220        peer.send(connection_id, user_channels_update.clone())
4221            .trace_err();
4222        peer.send(connection_id, update.clone()).trace_err();
4223    }
4224}
4225
4226fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4227    proto::UpdateUserChannels {
4228        channel_memberships: channels
4229            .channel_memberships
4230            .iter()
4231            .map(|m| proto::ChannelMembership {
4232                channel_id: m.channel_id.to_proto(),
4233                role: m.role.into(),
4234            })
4235            .collect(),
4236        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4237        observed_channel_message_id: channels.observed_channel_messages.clone(),
4238    }
4239}
4240
4241fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4242    let mut update = proto::UpdateChannels::default();
4243
4244    for channel in channels.channels {
4245        update.channels.push(channel.to_proto());
4246    }
4247
4248    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4249    update.latest_channel_message_ids = channels.latest_channel_messages;
4250
4251    for (channel_id, participants) in channels.channel_participants {
4252        update
4253            .channel_participants
4254            .push(proto::ChannelParticipants {
4255                channel_id: channel_id.to_proto(),
4256                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4257            });
4258    }
4259
4260    for channel in channels.invited_channels {
4261        update.channel_invitations.push(channel.to_proto());
4262    }
4263
4264    update
4265}
4266
4267fn build_initial_contacts_update(
4268    contacts: Vec<db::Contact>,
4269    pool: &ConnectionPool,
4270) -> proto::UpdateContacts {
4271    let mut update = proto::UpdateContacts::default();
4272
4273    for contact in contacts {
4274        match contact {
4275            db::Contact::Accepted { user_id, busy } => {
4276                update.contacts.push(contact_for_user(user_id, busy, pool));
4277            }
4278            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4279            db::Contact::Incoming { user_id } => {
4280                update
4281                    .incoming_requests
4282                    .push(proto::IncomingContactRequest {
4283                        requester_id: user_id.to_proto(),
4284                    })
4285            }
4286        }
4287    }
4288
4289    update
4290}
4291
4292fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4293    proto::Contact {
4294        user_id: user_id.to_proto(),
4295        online: pool.is_user_online(user_id),
4296        busy,
4297    }
4298}
4299
4300fn room_updated(room: &proto::Room, peer: &Peer) {
4301    broadcast(
4302        None,
4303        room.participants
4304            .iter()
4305            .filter_map(|participant| Some(participant.peer_id?.into())),
4306        |peer_id| {
4307            peer.send(
4308                peer_id,
4309                proto::RoomUpdated {
4310                    room: Some(room.clone()),
4311                },
4312            )
4313        },
4314    );
4315}
4316
4317fn channel_updated(
4318    channel: &db::channel::Model,
4319    room: &proto::Room,
4320    peer: &Peer,
4321    pool: &ConnectionPool,
4322) {
4323    let participants = room
4324        .participants
4325        .iter()
4326        .map(|p| p.user_id)
4327        .collect::<Vec<_>>();
4328
4329    broadcast(
4330        None,
4331        pool.channel_connection_ids(channel.root_id())
4332            .filter_map(|(channel_id, role)| {
4333                role.can_see_channel(channel.visibility)
4334                    .then_some(channel_id)
4335            }),
4336        |peer_id| {
4337            peer.send(
4338                peer_id,
4339                proto::UpdateChannels {
4340                    channel_participants: vec![proto::ChannelParticipants {
4341                        channel_id: channel.id.to_proto(),
4342                        participant_user_ids: participants.clone(),
4343                    }],
4344                    ..Default::default()
4345                },
4346            )
4347        },
4348    );
4349}
4350
4351async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4352    let db = session.db().await;
4353
4354    let contacts = db.get_contacts(user_id).await?;
4355    let busy = db.is_user_busy(user_id).await?;
4356
4357    let pool = session.connection_pool().await;
4358    let updated_contact = contact_for_user(user_id, busy, &pool);
4359    for contact in contacts {
4360        if let db::Contact::Accepted {
4361            user_id: contact_user_id,
4362            ..
4363        } = contact
4364        {
4365            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4366                session
4367                    .peer
4368                    .send(
4369                        contact_conn_id,
4370                        proto::UpdateContacts {
4371                            contacts: vec![updated_contact.clone()],
4372                            remove_contacts: Default::default(),
4373                            incoming_requests: Default::default(),
4374                            remove_incoming_requests: Default::default(),
4375                            outgoing_requests: Default::default(),
4376                            remove_outgoing_requests: Default::default(),
4377                        },
4378                    )
4379                    .trace_err();
4380            }
4381        }
4382    }
4383    Ok(())
4384}
4385
4386async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4387    let mut contacts_to_update = HashSet::default();
4388
4389    let room_id;
4390    let canceled_calls_to_user_ids;
4391    let livekit_room;
4392    let delete_livekit_room;
4393    let room;
4394    let channel;
4395
4396    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4397        contacts_to_update.insert(session.user_id());
4398
4399        for project in left_room.left_projects.values() {
4400            project_left(project, session);
4401        }
4402
4403        room_id = RoomId::from_proto(left_room.room.id);
4404        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4405        livekit_room = mem::take(&mut left_room.room.livekit_room);
4406        delete_livekit_room = left_room.deleted;
4407        room = mem::take(&mut left_room.room);
4408        channel = mem::take(&mut left_room.channel);
4409
4410        room_updated(&room, &session.peer);
4411    } else {
4412        return Ok(());
4413    }
4414
4415    if let Some(channel) = channel {
4416        channel_updated(
4417            &channel,
4418            &room,
4419            &session.peer,
4420            &*session.connection_pool().await,
4421        );
4422    }
4423
4424    {
4425        let pool = session.connection_pool().await;
4426        for canceled_user_id in canceled_calls_to_user_ids {
4427            for connection_id in pool.user_connection_ids(canceled_user_id) {
4428                session
4429                    .peer
4430                    .send(
4431                        connection_id,
4432                        proto::CallCanceled {
4433                            room_id: room_id.to_proto(),
4434                        },
4435                    )
4436                    .trace_err();
4437            }
4438            contacts_to_update.insert(canceled_user_id);
4439        }
4440    }
4441
4442    for contact_user_id in contacts_to_update {
4443        update_user_contacts(contact_user_id, session).await?;
4444    }
4445
4446    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4447        live_kit
4448            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4449            .await
4450            .trace_err();
4451
4452        if delete_livekit_room {
4453            live_kit.delete_room(livekit_room).await.trace_err();
4454        }
4455    }
4456
4457    Ok(())
4458}
4459
4460async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4461    let left_channel_buffers = session
4462        .db()
4463        .await
4464        .leave_channel_buffers(session.connection_id)
4465        .await?;
4466
4467    for left_buffer in left_channel_buffers {
4468        channel_buffer_updated(
4469            session.connection_id,
4470            left_buffer.connections,
4471            &proto::UpdateChannelBufferCollaborators {
4472                channel_id: left_buffer.channel_id.to_proto(),
4473                collaborators: left_buffer.collaborators,
4474            },
4475            &session.peer,
4476        );
4477    }
4478
4479    Ok(())
4480}
4481
4482fn project_left(project: &db::LeftProject, session: &Session) {
4483    for connection_id in &project.connection_ids {
4484        if project.should_unshare {
4485            session
4486                .peer
4487                .send(
4488                    *connection_id,
4489                    proto::UnshareProject {
4490                        project_id: project.id.to_proto(),
4491                    },
4492                )
4493                .trace_err();
4494        } else {
4495            session
4496                .peer
4497                .send(
4498                    *connection_id,
4499                    proto::RemoveProjectCollaborator {
4500                        project_id: project.id.to_proto(),
4501                        peer_id: Some(session.connection_id.into()),
4502                    },
4503                )
4504                .trace_err();
4505        }
4506    }
4507}
4508
/// Extension trait for turning a fallible value into an `Option`.
///
/// The `Result` implementation in this file logs the error and returns `None`
/// on failure.
pub trait ResultExt {
    /// The success value yielded by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self`, returning the success value if there is one.
    fn trace_err(self) -> Option<Self::Ok>;
}
4514
4515impl<T, E> ResultExt for Result<T, E>
4516where
4517    E: std::fmt::Debug,
4518{
4519    type Ok = T;
4520
4521    #[track_caller]
4522    fn trace_err(self) -> Option<T> {
4523        match self {
4524            Ok(value) => Some(value),
4525            Err(error) => {
4526                tracing::error!("{:?}", error);
4527                None
4528            }
4529        }
4530    }
4531}