rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use rpc::proto::split_repository_update;
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  46    TryStreamExt,
  47};
  48use prometheus::{register_int_gauge, IntGauge};
  49use rpc::{
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68        Arc, OnceLock,
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{watch, MutexGuard, Semaphore};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    field::{self},
  77    info_span, instrument, Instrument,
  78};
  79
/// Reconnection grace period (30 seconds).
///
/// NOTE(review): not referenced in this chunk; presumably how long a dropped client may
/// take to reconnect before its session state is discarded — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` waits before cleaning up resources that the database reports
/// as belonging to stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel message history.
/// NOTE(review): not used in this chunk — presumably consumed by the channel-message
/// handlers; confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length.
/// NOTE(review): not used in this chunk; whether this counts bytes or chars is not
/// visible here — confirm at the use site.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries.
/// NOTE(review): not used in this chunk — confirm at the use site.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`.
///
/// Receives a boxed envelope plus a `Session` and returns a boxed `'static` future.
/// `Server::add_handler` builds these by downcasting the envelope to the concrete
/// `TypedEnvelope<M>` for the message type it was registered under.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106#[derive(Clone, Debug)]
 107pub enum Principal {
 108    User(User),
 109    Impersonated { user: User, admin: User },
 110}
 111
 112impl Principal {
 113    fn update_span(&self, span: &tracing::Span) {
 114        match &self {
 115            Principal::User(user) => {
 116                span.record("user_id", user.id.0);
 117                span.record("login", &user.github_login);
 118            }
 119            Principal::Impersonated { user, admin } => {
 120                span.record("user_id", user.id.0);
 121                span.record("login", &user.github_login);
 122                span.record("impersonator", &admin.github_login);
 123            }
 124        }
 125    }
 126}
 127
/// Per-connection state passed (by value) to every message handler.
///
/// Derives `Clone`; the shared resources it holds (database, peer, connection pool,
/// app state, clients) are behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as (possibly an admin impersonating a user).
    principal: Principal,
    /// This connection's id.
    connection_id: ConnectionId,
    /// Shared database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; lock it through `Session::connection_pool`, which wraps
    /// the guard in a `!Send` `ConnectionPoolGuard`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Optional Supermaven admin client.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in this chunk; the allow keeps the unused-field
    // lint quiet — presumably retained for future use.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Optional client system identifier.
    /// NOTE(review): presumably populated from `SystemIdHeader`; confirm at the call site.
    system_id: Option<String>,
    // Held but never read here (underscore-prefixed).
    _executor: Executor,
}
 144
 145impl Session {
 146    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        let guard = self.db.lock().await;
 150        #[cfg(test)]
 151        tokio::task::yield_now().await;
 152        guard
 153    }
 154
 155    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.connection_pool.lock();
 159        ConnectionPoolGuard {
 160            guard,
 161            _not_send: PhantomData,
 162        }
 163    }
 164
 165    fn is_staff(&self) -> bool {
 166        match &self.principal {
 167            Principal::User(user) => user.admin,
 168            Principal::Impersonated { .. } => true,
 169        }
 170    }
 171
 172    pub async fn has_llm_subscription(
 173        &self,
 174        db: &MutexGuard<'_, DbHandle>,
 175    ) -> anyhow::Result<bool> {
 176        if self.is_staff() {
 177            return Ok(true);
 178        }
 179
 180        let user_id = self.user_id();
 181
 182        Ok(db.has_active_billing_subscription(user_id).await?)
 183    }
 184
 185    pub async fn current_plan(
 186        &self,
 187        _db: &MutexGuard<'_, DbHandle>,
 188    ) -> anyhow::Result<proto::Plan> {
 189        if self.is_staff() {
 190            Ok(proto::Plan::ZedPro)
 191        } else {
 192            Ok(proto::Plan::Free)
 193        }
 194    }
 195
 196    fn user_id(&self) -> UserId {
 197        match &self.principal {
 198            Principal::User(user) => user.id,
 199            Principal::Impersonated { user, .. } => user.id,
 200        }
 201    }
 202
 203    pub fn email(&self) -> Option<String> {
 204        match &self.principal {
 205            Principal::User(user) => user.email_address.clone(),
 206            Principal::Impersonated { user, .. } => user.email_address.clone(),
 207        }
 208    }
 209}
 210
 211impl Debug for Session {
 212    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 213        let mut result = f.debug_struct("Session");
 214        match &self.principal {
 215            Principal::User(user) => {
 216                result.field("user", &user.github_login);
 217            }
 218            Principal::Impersonated { user, admin } => {
 219                result.field("user", &user.github_login);
 220                result.field("impersonator", &admin.github_login);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct DbHandle(Arc<Database>);
 228
 229impl Deref for DbHandle {
 230    type Target = Database;
 231
 232    fn deref(&self) -> &Self::Target {
 233        self.0.as_ref()
 234    }
 235}
 236
/// The collaboration RPC server.
///
/// Owns the `Peer`, the shared connection pool, and the table of message handlers
/// registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown()` sends `true`, and the test-only `reset()` sends `false`.
    teardown: watch::Sender<bool>,
}
 245
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the whole guard `!Send`. A handler future
    // that held it across an `.await` would therefore fail the `Send` bound required by
    // `Server::add_handler`.
    _not_send: PhantomData<Rc<()>>,
}
 250
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The pool guard is serialized through `serialize_deref`, so the output contains the
/// pool's contents rather than the guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 257
 258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 259where
 260    S: Serializer,
 261    T: Deref<Target = U>,
 262    U: Serialize,
 263{
 264    Serialize::serialize(value.deref(), serializer)
 265}
 266
 267impl Server {
 268    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 269        let mut server = Self {
 270            id: parking_lot::Mutex::new(id),
 271            peer: Peer::new(id.0 as u32),
 272            app_state: app_state.clone(),
 273            connection_pool: Default::default(),
 274            handlers: Default::default(),
 275            teardown: watch::channel(false).0,
 276        };
 277
 278        server
 279            .add_request_handler(ping)
 280            .add_request_handler(create_room)
 281            .add_request_handler(join_room)
 282            .add_request_handler(rejoin_room)
 283            .add_request_handler(leave_room)
 284            .add_request_handler(set_room_participant_role)
 285            .add_request_handler(call)
 286            .add_request_handler(cancel_call)
 287            .add_message_handler(decline_call)
 288            .add_request_handler(update_participant_location)
 289            .add_request_handler(share_project)
 290            .add_message_handler(unshare_project)
 291            .add_request_handler(join_project)
 292            .add_message_handler(leave_project)
 293            .add_request_handler(update_project)
 294            .add_request_handler(update_worktree)
 295            .add_request_handler(update_repository)
 296            .add_request_handler(remove_repository)
 297            .add_message_handler(start_language_server)
 298            .add_message_handler(update_language_server)
 299            .add_message_handler(update_diagnostic_summary)
 300            .add_message_handler(update_worktree_settings)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 305            .add_request_handler(forward_find_search_candidates_request)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 307            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 308            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 309            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 310            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 311            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 312            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 313            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(
 319                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 320            )
 321            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 322            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 327            .add_request_handler(
 328                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 329            )
 330            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 332            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 333            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 334            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 335            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 336            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 337            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 339            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 340            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 341            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 342            .add_request_handler(
 343                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 344            )
 345            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 346            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 347            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 348            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 349            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 350            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 351            .add_message_handler(create_buffer_for_peer)
 352            .add_request_handler(update_buffer)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 357            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 359            .add_request_handler(get_users)
 360            .add_request_handler(fuzzy_search_users)
 361            .add_request_handler(request_contact)
 362            .add_request_handler(remove_contact)
 363            .add_request_handler(respond_to_contact_request)
 364            .add_message_handler(subscribe_to_channels)
 365            .add_request_handler(create_channel)
 366            .add_request_handler(delete_channel)
 367            .add_request_handler(invite_channel_member)
 368            .add_request_handler(remove_channel_member)
 369            .add_request_handler(set_channel_member_role)
 370            .add_request_handler(set_channel_visibility)
 371            .add_request_handler(rename_channel)
 372            .add_request_handler(join_channel_buffer)
 373            .add_request_handler(leave_channel_buffer)
 374            .add_message_handler(update_channel_buffer)
 375            .add_request_handler(rejoin_channel_buffers)
 376            .add_request_handler(get_channel_members)
 377            .add_request_handler(respond_to_channel_invite)
 378            .add_request_handler(join_channel)
 379            .add_request_handler(join_channel_chat)
 380            .add_message_handler(leave_channel_chat)
 381            .add_request_handler(send_channel_message)
 382            .add_request_handler(remove_channel_message)
 383            .add_request_handler(update_channel_message)
 384            .add_request_handler(get_channel_messages)
 385            .add_request_handler(get_channel_messages_by_id)
 386            .add_request_handler(get_notifications)
 387            .add_request_handler(mark_notification_as_read)
 388            .add_request_handler(move_channel)
 389            .add_request_handler(follow)
 390            .add_message_handler(unfollow)
 391            .add_message_handler(update_followers)
 392            .add_request_handler(get_private_user_info)
 393            .add_request_handler(get_llm_api_token)
 394            .add_request_handler(accept_terms_of_service)
 395            .add_message_handler(acknowledge_channel_message)
 396            .add_message_handler(acknowledge_buffer_version)
 397            .add_request_handler(get_supermaven_api_key)
 398            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 399            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 401            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 402            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 403            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 404            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 406            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 407            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 408            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 409            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 410            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 412            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 414            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 415            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 416            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 417            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 418            .add_message_handler(update_context)
 419            .add_request_handler({
 420                let app_state = app_state.clone();
 421                move |request, response, session| {
 422                    let app_state = app_state.clone();
 423                    async move {
 424                        count_language_model_tokens(request, response, session, &app_state.config)
 425                            .await
 426                    }
 427                }
 428            })
 429            .add_request_handler(get_cached_embeddings)
 430            .add_request_handler({
 431                let app_state = app_state.clone();
 432                move |request, response, session| {
 433                    compute_embeddings(
 434                        request,
 435                        response,
 436                        session,
 437                        app_state.config.openai_api_key.clone(),
 438                    )
 439                }
 440            });
 441
 442        Arc::new(server)
 443    }
 444
    /// Spawns a detached background cleanup task and returns immediately with `Ok(())`.
    ///
    /// After `CLEANUP_TIMEOUT`, the task asks the database for rooms and channel buffers
    /// returned by `stale_server_resource_ids` for this environment and server id. It
    /// clears stale participants and collaborators, notifies the affected clients and
    /// contacts, deletes emptied LiveKit rooms, and finally calls `delete_stale_servers`.
    /// Failures inside the task are logged through `trace_err` and never propagated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here (not inside the task) so the delay is measured from `start()`.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the remaining connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing stale participants succeeds; otherwise
                        // the steps below are no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only an emptied room's LiveKit room is deleted (see below).
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled. The pool lock is
                        // scoped to this block so it is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry and push it to every
                        // connection of their accepted contacts. Database reads happen
                        // before the pool lock is taken.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once it has no participants left (only
                        // when a LiveKit client is configured).
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 590
 591    pub fn teardown(&self) {
 592        self.peer.teardown();
 593        self.connection_pool.lock().reset();
 594        let _ = self.teardown.send(true);
 595    }
 596
 597    #[cfg(test)]
 598    pub fn reset(&self, id: ServerId) {
 599        self.teardown();
 600        *self.id.lock() = id;
 601        self.peer.reset(id.0 as u32);
 602        let _ = self.teardown.send(false);
 603    }
 604
 605    #[cfg(test)]
 606    pub fn id(&self) -> ServerId {
 607        *self.id.lock()
 608    }
 609
 610    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 611    where
 612        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 613        Fut: 'static + Send + Future<Output = Result<()>>,
 614        M: EnvelopedMessage,
 615    {
 616        let prev_handler = self.handlers.insert(
 617            TypeId::of::<M>(),
 618            Box::new(move |envelope, session| {
 619                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 620                let received_at = envelope.received_at;
 621                tracing::info!("message received");
 622                let start_time = Instant::now();
 623                let future = (handler)(*envelope, session);
 624                async move {
 625                    let result = future.await;
 626                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 627                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 628                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 629                    let payload_type = M::NAME;
 630
 631                    match result {
 632                        Err(error) => {
 633                            tracing::error!(
 634                                ?error,
 635                                total_duration_ms,
 636                                processing_duration_ms,
 637                                queue_duration_ms,
 638                                payload_type,
 639                                "error handling message"
 640                            )
 641                        }
 642                        Ok(()) => tracing::info!(
 643                            total_duration_ms,
 644                            processing_duration_ms,
 645                            queue_duration_ms,
 646                            "finished handling message"
 647                        ),
 648                    }
 649                }
 650                .boxed()
 651            }),
 652        );
 653        if prev_handler.is_some() {
 654            panic!("registered a handler for the same message twice");
 655        }
 656        self
 657    }
 658
 659    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 660    where
 661        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 662        Fut: 'static + Send + Future<Output = Result<()>>,
 663        M: EnvelopedMessage,
 664    {
 665        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 666        self
 667    }
 668
 669    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 670    where
 671        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 672        Fut: Send + Future<Output = Result<()>>,
 673        M: RequestMessage,
 674    {
 675        let handler = Arc::new(handler);
 676        self.add_handler(move |envelope, session| {
 677            let receipt = envelope.receipt();
 678            let handler = handler.clone();
 679            async move {
 680                let peer = session.peer.clone();
 681                let responded = Arc::new(AtomicBool::default());
 682                let response = Response {
 683                    peer: peer.clone(),
 684                    responded: responded.clone(),
 685                    receipt,
 686                };
 687                match (handler)(envelope.payload, response, session).await {
 688                    Ok(()) => {
 689                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 690                            Ok(())
 691                        } else {
 692                            Err(anyhow!("handler did not send a response"))?
 693                        }
 694                    }
 695                    Err(error) => {
 696                        let proto_err = match &error {
 697                            Error::Internal(err) => err.to_proto(),
 698                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 699                        };
 700                        peer.respond_with_error(receipt, proto_err)?;
 701                        Err(error)
 702                    }
 703                }
 704            }
 705        })
 706    }
 707
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. exits immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the assigned
    ///    connection id on the span;
    /// 3. builds the per-connection [`Session`] (HTTP client, optional
    ///    Supermaven admin client, database handle, ...);
    /// 4. runs `send_initial_client_update` (which also reports the connection
    ///    id through `send_connection_id` when one is given);
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled;
    /// 6. calls `connection_lost` for cleanup, unless teardown ended the loop
    ///    (that path returns without cleanup).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: connection_id once the peer
        // assigns one, the user fields presumably by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Each connection gets its own HTTP client, identified by the server version.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): the connection is already registered with the peer
                    // at this point, and this return skips `connection_lost`. Confirm
                    // the peer-side state is released elsewhere.
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this skips `connection_lost`. If the failure happened
                // after `send_initial_client_update` added the connection to the pool,
                // that pool entry (and the peer connection) may be left behind. Verify.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256.
            // A permit is acquired before reading each message and released when that
            // message's handler future completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Rebuilt on every iteration. If another arm wins, this future is dropped,
                // which releases any permit it already acquired.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in the order written: teardown first, then
                // I/O completion, then progress on foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown returns without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is held until
                                // the handler finishes, whichever way the future is driven.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // Unknown message type: logged and dropped; its permit is
                                // released at the end of this arm.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 856
 857    async fn send_initial_client_update(
 858        &self,
 859        connection_id: ConnectionId,
 860        principal: &Principal,
 861        zed_version: ZedVersion,
 862        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 863        session: &Session,
 864    ) -> Result<()> {
 865        self.peer.send(
 866            connection_id,
 867            proto::Hello {
 868                peer_id: Some(connection_id.into()),
 869            },
 870        )?;
 871        tracing::info!("sent hello message");
 872        if let Some(send_connection_id) = send_connection_id.take() {
 873            let _ = send_connection_id.send(connection_id);
 874        }
 875
 876        match principal {
 877            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 878                if !user.connected_once {
 879                    self.peer.send(connection_id, proto::ShowContacts {})?;
 880                    self.app_state
 881                        .db
 882                        .set_user_connected_once(user.id, true)
 883                        .await?;
 884                }
 885
 886                update_user_plan(user.id, session).await?;
 887
 888                let contacts = self.app_state.db.get_contacts(user.id).await?;
 889
 890                {
 891                    let mut pool = self.connection_pool.lock();
 892                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 893                    self.peer.send(
 894                        connection_id,
 895                        build_initial_contacts_update(contacts, &pool),
 896                    )?;
 897                }
 898
 899                if should_auto_subscribe_to_channels(zed_version) {
 900                    subscribe_user_to_channels(user.id, session).await?;
 901                }
 902
 903                if let Some(incoming_call) =
 904                    self.app_state.db.incoming_call_for_user(user.id).await?
 905                {
 906                    self.peer.send(connection_id, incoming_call)?;
 907                }
 908
 909                update_user_contacts(user.id, session).await?;
 910            }
 911        }
 912
 913        Ok(())
 914    }
 915
 916    pub async fn invite_code_redeemed(
 917        self: &Arc<Self>,
 918        inviter_id: UserId,
 919        invitee_id: UserId,
 920    ) -> Result<()> {
 921        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 922            if let Some(code) = &user.invite_code {
 923                let pool = self.connection_pool.lock();
 924                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 925                for connection_id in pool.user_connection_ids(inviter_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateContacts {
 929                            contacts: vec![invitee_contact.clone()],
 930                            ..Default::default()
 931                        },
 932                    )?;
 933                    self.peer.send(
 934                        connection_id,
 935                        proto::UpdateInviteInfo {
 936                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 937                            count: user.invite_count as u32,
 938                        },
 939                    )?;
 940                }
 941            }
 942        }
 943        Ok(())
 944    }
 945
 946    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 947        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 948            if let Some(invite_code) = &user.invite_code {
 949                let pool = self.connection_pool.lock();
 950                for connection_id in pool.user_connection_ids(user_id) {
 951                    self.peer.send(
 952                        connection_id,
 953                        proto::UpdateInviteInfo {
 954                            url: format!(
 955                                "{}{}",
 956                                self.app_state.config.invite_link_prefix, invite_code
 957                            ),
 958                            count: user.invite_count as u32,
 959                        },
 960                    )?;
 961                }
 962            }
 963        }
 964        Ok(())
 965    }
 966
 967    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 968        let pool = self.connection_pool.lock();
 969        for connection_id in pool.user_connection_ids(user_id) {
 970            self.peer
 971                .send(connection_id, proto::RefreshLlmToken {})
 972                .trace_err();
 973        }
 974    }
 975
 976    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 977        ServerSnapshot {
 978            connection_pool: ConnectionPoolGuard {
 979                guard: self.connection_pool.lock(),
 980                _not_send: PhantomData,
 981            },
 982            peer: &self.peer,
 983        }
 984    }
 985}
 986
 987impl Deref for ConnectionPoolGuard<'_> {
 988    type Target = ConnectionPool;
 989
 990    fn deref(&self) -> &Self::Target {
 991        &self.guard
 992    }
 993}
 994
 995impl DerefMut for ConnectionPoolGuard<'_> {
 996    fn deref_mut(&mut self) -> &mut Self::Target {
 997        &mut self.guard
 998    }
 999}
1000
1001impl Drop for ConnectionPoolGuard<'_> {
1002    fn drop(&mut self) {
1003        #[cfg(test)]
1004        self.check_invariants();
1005    }
1006}
1007
1008fn broadcast<F>(
1009    sender_id: Option<ConnectionId>,
1010    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1011    mut f: F,
1012) where
1013    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1014{
1015    for receiver_id in receiver_ids {
1016        if Some(receiver_id) != sender_id {
1017            if let Err(error) = f(receiver_id) {
1018                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1019            }
1020        }
1021    }
1022}
1023
1024pub struct ProtocolVersion(u32);
1025
1026impl Header for ProtocolVersion {
1027    fn name() -> &'static HeaderName {
1028        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1029        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1030    }
1031
1032    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1033    where
1034        Self: Sized,
1035        I: Iterator<Item = &'i axum::http::HeaderValue>,
1036    {
1037        let version = values
1038            .next()
1039            .ok_or_else(axum::headers::Error::invalid)?
1040            .to_str()
1041            .map_err(|_| axum::headers::Error::invalid())?
1042            .parse()
1043            .map_err(|_| axum::headers::Error::invalid())?;
1044        Ok(Self(version))
1045    }
1046
1047    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1048        values.extend([self.0.to_string().parse().unwrap()]);
1049    }
1050}
1051
1052pub struct AppVersionHeader(SemanticVersion);
1053impl Header for AppVersionHeader {
1054    fn name() -> &'static HeaderName {
1055        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1057    }
1058
1059    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060    where
1061        Self: Sized,
1062        I: Iterator<Item = &'i axum::http::HeaderValue>,
1063    {
1064        let version = values
1065            .next()
1066            .ok_or_else(axum::headers::Error::invalid)?
1067            .to_str()
1068            .map_err(|_| axum::headers::Error::invalid())?
1069            .parse()
1070            .map_err(|_| axum::headers::Error::invalid())?;
1071        Ok(Self(version))
1072    }
1073
1074    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075        values.extend([self.0.to_string().parse().unwrap()]);
1076    }
1077}
1078
1079pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1080    Router::new()
1081        .route("/rpc", get(handle_websocket_request))
1082        .layer(
1083            ServiceBuilder::new()
1084                .layer(Extension(server.app_state.clone()))
1085                .layer(middleware::from_fn(auth::validate_header)),
1086        )
1087        .route("/metrics", get(handle_metrics))
1088        .layer(Extension(server))
1089}
1090
1091pub async fn handle_websocket_request(
1092    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1093    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1094    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1095    Extension(server): Extension<Arc<Server>>,
1096    Extension(principal): Extension<Principal>,
1097    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1098    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1099    ws: WebSocketUpgrade,
1100) -> axum::response::Response {
1101    if protocol_version != rpc::PROTOCOL_VERSION {
1102        return (
1103            StatusCode::UPGRADE_REQUIRED,
1104            "client must be upgraded".to_string(),
1105        )
1106            .into_response();
1107    }
1108
1109    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1110        return (
1111            StatusCode::UPGRADE_REQUIRED,
1112            "no version header found".to_string(),
1113        )
1114            .into_response();
1115    };
1116
1117    if !version.can_collaborate() {
1118        return (
1119            StatusCode::UPGRADE_REQUIRED,
1120            "client must be upgraded".to_string(),
1121        )
1122            .into_response();
1123    }
1124
1125    let socket_address = socket_address.to_string();
1126    ws.on_upgrade(move |socket| {
1127        let socket = socket
1128            .map_ok(to_tungstenite_message)
1129            .err_into()
1130            .with(|message| async move { to_axum_message(message) });
1131        let connection = Connection::new(Box::pin(socket));
1132        async move {
1133            server
1134                .handle_connection(
1135                    connection,
1136                    socket_address,
1137                    principal,
1138                    version,
1139                    country_code_header.map(|header| header.to_string()),
1140                    system_id_header.map(|header| header.to_string()),
1141                    None,
1142                    Executor::Production,
1143                )
1144                .await;
1145        }
1146    })
1147}
1148
1149pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1150    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151    let connections_metric = CONNECTIONS_METRIC
1152        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1153
1154    let connections = server
1155        .connection_pool
1156        .lock()
1157        .connections()
1158        .filter(|connection| !connection.admin)
1159        .count();
1160    connections_metric.set(connections as _);
1161
1162    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1164        register_int_gauge!(
1165            "shared_projects",
1166            "number of open projects with one or more guests"
1167        )
1168        .unwrap()
1169    });
1170
1171    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1172    shared_projects_metric.set(shared_projects as _);
1173
1174    let encoder = prometheus::TextEncoder::new();
1175    let metric_families = prometheus::gather();
1176    let encoded_metrics = encoder
1177        .encode_to_string(&metric_families)
1178        .map_err(|err| anyhow!("{}", err))?;
1179    Ok(encoded_metrics)
1180}
1181
1182#[instrument(err, skip(executor))]
1183async fn connection_lost(
1184    session: Session,
1185    mut teardown: watch::Receiver<bool>,
1186    executor: Executor,
1187) -> Result<()> {
1188    session.peer.disconnect(session.connection_id);
1189    session
1190        .connection_pool()
1191        .await
1192        .remove_connection(session.connection_id)?;
1193
1194    session
1195        .db()
1196        .await
1197        .connection_lost(session.connection_id)
1198        .await
1199        .trace_err();
1200
1201    futures::select_biased! {
1202        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1203
1204            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1205            leave_room_for_session(&session, session.connection_id).await.trace_err();
1206            leave_channel_buffers_for_session(&session)
1207                .await
1208                .trace_err();
1209
1210            if !session
1211                .connection_pool()
1212                .await
1213                .is_user_online(session.user_id())
1214            {
1215                let db = session.db().await;
1216                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1217                    room_updated(&room, &session.peer);
1218                }
1219            }
1220
1221            update_user_contacts(session.user_id(), &session).await?;
1222        },
1223        _ = teardown.changed().fuse() => {}
1224    }
1225
1226    Ok(())
1227}
1228
1229/// Acknowledges a ping from a client, used to keep the connection alive.
1230async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1231    response.send(proto::Ack {})?;
1232    Ok(())
1233}
1234
1235/// Creates a new room for calling (outside of channels)
1236async fn create_room(
1237    _request: proto::CreateRoom,
1238    response: Response<proto::CreateRoom>,
1239    session: Session,
1240) -> Result<()> {
1241    let livekit_room = nanoid::nanoid!(30);
1242
1243    let live_kit_connection_info = util::maybe!(async {
1244        let live_kit = session.app_state.livekit_client.as_ref();
1245        let live_kit = live_kit?;
1246        let user_id = session.user_id().to_string();
1247
1248        let token = live_kit
1249            .room_token(&livekit_room, &user_id.to_string())
1250            .trace_err()?;
1251
1252        Some(proto::LiveKitConnectionInfo {
1253            server_url: live_kit.url().into(),
1254            token,
1255            can_publish: true,
1256        })
1257    })
1258    .await;
1259
1260    let room = session
1261        .db()
1262        .await
1263        .create_room(session.user_id(), session.connection_id, &livekit_room)
1264        .await?;
1265
1266    response.send(proto::CreateRoomResponse {
1267        room: Some(room.clone()),
1268        live_kit_connection_info,
1269    })?;
1270
1271    update_user_contacts(session.user_id(), &session).await?;
1272    Ok(())
1273}
1274
1275/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1276async fn join_room(
1277    request: proto::JoinRoom,
1278    response: Response<proto::JoinRoom>,
1279    session: Session,
1280) -> Result<()> {
1281    let room_id = RoomId::from_proto(request.id);
1282
1283    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1284
1285    if let Some(channel_id) = channel_id {
1286        return join_channel_internal(channel_id, Box::new(response), session).await;
1287    }
1288
1289    let joined_room = {
1290        let room = session
1291            .db()
1292            .await
1293            .join_room(room_id, session.user_id(), session.connection_id)
1294            .await?;
1295        room_updated(&room.room, &session.peer);
1296        room.into_inner()
1297    };
1298
1299    for connection_id in session
1300        .connection_pool()
1301        .await
1302        .user_connection_ids(session.user_id())
1303    {
1304        session
1305            .peer
1306            .send(
1307                connection_id,
1308                proto::CallCanceled {
1309                    room_id: room_id.to_proto(),
1310                },
1311            )
1312            .trace_err();
1313    }
1314
1315    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1316    {
1317        live_kit
1318            .room_token(
1319                &joined_room.room.livekit_room,
1320                &session.user_id().to_string(),
1321            )
1322            .trace_err()
1323            .map(|token| proto::LiveKitConnectionInfo {
1324                server_url: live_kit.url().into(),
1325                token,
1326                can_publish: true,
1327            })
1328    } else {
1329        None
1330    };
1331
1332    response.send(proto::JoinRoomResponse {
1333        room: Some(joined_room.room),
1334        channel_id: None,
1335        live_kit_connection_info,
1336    })?;
1337
1338    update_user_contacts(session.user_id(), &session).await?;
1339    Ok(())
1340}
1341
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-attaches this connection to the room in the database, then replies with
/// the room, the projects this client re-shares (with their collaborators),
/// and the projects it rejoined. Collaborators of re-shared projects are told
/// about the new peer id and receive the project's worktree list. Rejoined
/// projects are streamed via `notify_rejoined_projects`. If the room belongs
/// to a channel, the channel is updated, and finally the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in from inside the block below so they outlive `rejoined_room`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator from this client's old peer id to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared project's worktree list to the other collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the database result before leaving this scope; only the room and
        // channel are kept for the steps below.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1433
1434fn notify_rejoined_projects(
1435    rejoined_projects: &mut Vec<RejoinedProject>,
1436    session: &Session,
1437) -> Result<()> {
1438    for project in rejoined_projects.iter() {
1439        for collaborator in &project.collaborators {
1440            session
1441                .peer
1442                .send(
1443                    collaborator.connection_id,
1444                    proto::UpdateProjectCollaborator {
1445                        project_id: project.id.to_proto(),
1446                        old_peer_id: Some(project.old_connection_id.into()),
1447                        new_peer_id: Some(session.connection_id.into()),
1448                    },
1449                )
1450                .trace_err();
1451        }
1452    }
1453
1454    for project in rejoined_projects {
1455        for worktree in mem::take(&mut project.worktrees) {
1456            // Stream this worktree's entries.
1457            let message = proto::UpdateWorktree {
1458                project_id: project.id.to_proto(),
1459                worktree_id: worktree.id,
1460                abs_path: worktree.abs_path.clone(),
1461                root_name: worktree.root_name,
1462                updated_entries: worktree.updated_entries,
1463                removed_entries: worktree.removed_entries,
1464                scan_id: worktree.scan_id,
1465                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1466                updated_repositories: worktree.updated_repositories,
1467                removed_repositories: worktree.removed_repositories,
1468            };
1469            for update in proto::split_worktree_update(message) {
1470                session.peer.send(session.connection_id, update)?;
1471            }
1472
1473            // Stream this worktree's diagnostics.
1474            for summary in worktree.diagnostic_summaries {
1475                session.peer.send(
1476                    session.connection_id,
1477                    proto::UpdateDiagnosticSummary {
1478                        project_id: project.id.to_proto(),
1479                        worktree_id: worktree.id,
1480                        summary: Some(summary),
1481                    },
1482                )?;
1483            }
1484
1485            for settings_file in worktree.settings_files {
1486                session.peer.send(
1487                    session.connection_id,
1488                    proto::UpdateWorktreeSettings {
1489                        project_id: project.id.to_proto(),
1490                        worktree_id: worktree.id,
1491                        path: settings_file.path,
1492                        content: Some(settings_file.content),
1493                        kind: Some(settings_file.kind.to_proto().into()),
1494                    },
1495                )?;
1496            }
1497        }
1498
1499        for repository in mem::take(&mut project.updated_repositories) {
1500            for update in split_repository_update(repository) {
1501                session.peer.send(session.connection_id, update)?;
1502            }
1503        }
1504
1505        for id in mem::take(&mut project.removed_repositories) {
1506            session.peer.send(
1507                session.connection_id,
1508                proto::RemoveRepository {
1509                    project_id: project.id.to_proto(),
1510                    id,
1511                },
1512            )?;
1513        }
1514    }
1515
1516    Ok(())
1517}
1518
1519/// leave room disconnects from the room.
1520async fn leave_room(
1521    _: proto::LeaveRoom,
1522    response: Response<proto::LeaveRoom>,
1523    session: Session,
1524) -> Result<()> {
1525    leave_room_for_session(&session, session.connection_id).await?;
1526    response.send(proto::Ack {})?;
1527    Ok(())
1528}
1529
1530/// Updates the permissions of someone else in the room.
1531async fn set_room_participant_role(
1532    request: proto::SetRoomParticipantRole,
1533    response: Response<proto::SetRoomParticipantRole>,
1534    session: Session,
1535) -> Result<()> {
1536    let user_id = UserId::from_proto(request.user_id);
1537    let role = ChannelRole::from(request.role());
1538
1539    let (livekit_room, can_publish) = {
1540        let room = session
1541            .db()
1542            .await
1543            .set_room_participant_role(
1544                session.user_id(),
1545                RoomId::from_proto(request.room_id),
1546                user_id,
1547                role,
1548            )
1549            .await?;
1550
1551        let livekit_room = room.livekit_room.clone();
1552        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1553        room_updated(&room, &session.peer);
1554        (livekit_room, can_publish)
1555    };
1556
1557    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1558        live_kit
1559            .update_participant(
1560                livekit_room.clone(),
1561                request.user_id.to_string(),
1562                livekit_api::proto::ParticipantPermission {
1563                    can_subscribe: true,
1564                    can_publish,
1565                    can_publish_data: can_publish,
1566                    hidden: false,
1567                    recorder: false,
1568                },
1569            )
1570            .await
1571            .trace_err();
1572    }
1573
1574    response.send(proto::Ack {})?;
1575    Ok(())
1576}
1577
1578/// Call someone else into the current room
1579async fn call(
1580    request: proto::Call,
1581    response: Response<proto::Call>,
1582    session: Session,
1583) -> Result<()> {
1584    let room_id = RoomId::from_proto(request.room_id);
1585    let calling_user_id = session.user_id();
1586    let calling_connection_id = session.connection_id;
1587    let called_user_id = UserId::from_proto(request.called_user_id);
1588    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1589    if !session
1590        .db()
1591        .await
1592        .has_contact(calling_user_id, called_user_id)
1593        .await?
1594    {
1595        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1596    }
1597
1598    let incoming_call = {
1599        let (room, incoming_call) = &mut *session
1600            .db()
1601            .await
1602            .call(
1603                room_id,
1604                calling_user_id,
1605                calling_connection_id,
1606                called_user_id,
1607                initial_project_id,
1608            )
1609            .await?;
1610        room_updated(room, &session.peer);
1611        mem::take(incoming_call)
1612    };
1613    update_user_contacts(called_user_id, &session).await?;
1614
1615    let mut calls = session
1616        .connection_pool()
1617        .await
1618        .user_connection_ids(called_user_id)
1619        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1620        .collect::<FuturesUnordered<_>>();
1621
1622    while let Some(call_response) = calls.next().await {
1623        match call_response.as_ref() {
1624            Ok(_) => {
1625                response.send(proto::Ack {})?;
1626                return Ok(());
1627            }
1628            Err(_) => {
1629                call_response.trace_err();
1630            }
1631        }
1632    }
1633
1634    {
1635        let room = session
1636            .db()
1637            .await
1638            .call_failed(room_id, called_user_id)
1639            .await?;
1640        room_updated(&room, &session.peer);
1641    }
1642    update_user_contacts(called_user_id, &session).await?;
1643
1644    Err(anyhow!("failed to ring user"))?
1645}
1646
1647/// Cancel an outgoing call.
1648async fn cancel_call(
1649    request: proto::CancelCall,
1650    response: Response<proto::CancelCall>,
1651    session: Session,
1652) -> Result<()> {
1653    let called_user_id = UserId::from_proto(request.called_user_id);
1654    let room_id = RoomId::from_proto(request.room_id);
1655    {
1656        let room = session
1657            .db()
1658            .await
1659            .cancel_call(room_id, session.connection_id, called_user_id)
1660            .await?;
1661        room_updated(&room, &session.peer);
1662    }
1663
1664    for connection_id in session
1665        .connection_pool()
1666        .await
1667        .user_connection_ids(called_user_id)
1668    {
1669        session
1670            .peer
1671            .send(
1672                connection_id,
1673                proto::CallCanceled {
1674                    room_id: room_id.to_proto(),
1675                },
1676            )
1677            .trace_err();
1678    }
1679    response.send(proto::Ack {})?;
1680
1681    update_user_contacts(called_user_id, &session).await?;
1682    Ok(())
1683}
1684
1685/// Decline an incoming call.
1686async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1687    let room_id = RoomId::from_proto(message.room_id);
1688    {
1689        let room = session
1690            .db()
1691            .await
1692            .decline_call(Some(room_id), session.user_id())
1693            .await?
1694            .ok_or_else(|| anyhow!("failed to decline call"))?;
1695        room_updated(&room, &session.peer);
1696    }
1697
1698    for connection_id in session
1699        .connection_pool()
1700        .await
1701        .user_connection_ids(session.user_id())
1702    {
1703        session
1704            .peer
1705            .send(
1706                connection_id,
1707                proto::CallCanceled {
1708                    room_id: room_id.to_proto(),
1709                },
1710            )
1711            .trace_err();
1712    }
1713    update_user_contacts(session.user_id(), &session).await?;
1714    Ok(())
1715}
1716
1717/// Updates other participants in the room with your current location.
1718async fn update_participant_location(
1719    request: proto::UpdateParticipantLocation,
1720    response: Response<proto::UpdateParticipantLocation>,
1721    session: Session,
1722) -> Result<()> {
1723    let room_id = RoomId::from_proto(request.room_id);
1724    let location = request
1725        .location
1726        .ok_or_else(|| anyhow!("invalid location"))?;
1727
1728    let db = session.db().await;
1729    let room = db
1730        .update_room_participant_location(room_id, session.connection_id, location)
1731        .await?;
1732
1733    room_updated(&room, &session.peer);
1734    response.send(proto::Ack {})?;
1735    Ok(())
1736}
1737
1738/// Share a project into the room.
1739async fn share_project(
1740    request: proto::ShareProject,
1741    response: Response<proto::ShareProject>,
1742    session: Session,
1743) -> Result<()> {
1744    let (project_id, room) = &*session
1745        .db()
1746        .await
1747        .share_project(
1748            RoomId::from_proto(request.room_id),
1749            session.connection_id,
1750            &request.worktrees,
1751            request.is_ssh_project,
1752        )
1753        .await?;
1754    response.send(proto::ShareProjectResponse {
1755        project_id: project_id.to_proto(),
1756    })?;
1757    room_updated(room, &session.peer);
1758
1759    Ok(())
1760}
1761
1762/// Unshare a project from the room.
1763async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1764    let project_id = ProjectId::from_proto(message.project_id);
1765    unshare_project_internal(project_id, session.connection_id, &session).await
1766}
1767
1768async fn unshare_project_internal(
1769    project_id: ProjectId,
1770    connection_id: ConnectionId,
1771    session: &Session,
1772) -> Result<()> {
1773    let delete = {
1774        let room_guard = session
1775            .db()
1776            .await
1777            .unshare_project(project_id, connection_id)
1778            .await?;
1779
1780        let (delete, room, guest_connection_ids) = &*room_guard;
1781
1782        let message = proto::UnshareProject {
1783            project_id: project_id.to_proto(),
1784        };
1785
1786        broadcast(
1787            Some(connection_id),
1788            guest_connection_ids.iter().copied(),
1789            |conn_id| session.peer.send(conn_id, message.clone()),
1790        );
1791        if let Some(room) = room {
1792            room_updated(room, &session.peer);
1793        }
1794
1795        *delete
1796    };
1797
1798    if delete {
1799        let db = session.db().await;
1800        db.delete_project(project_id).await?;
1801    }
1802
1803    Ok(())
1804}
1805
1806/// Join someone elses shared project.
1807async fn join_project(
1808    request: proto::JoinProject,
1809    response: Response<proto::JoinProject>,
1810    session: Session,
1811) -> Result<()> {
1812    let project_id = ProjectId::from_proto(request.project_id);
1813
1814    tracing::info!(%project_id, "join project");
1815
1816    let db = session.db().await;
1817    let (project, replica_id) = &mut *db
1818        .join_project(project_id, session.connection_id, session.user_id())
1819        .await?;
1820    drop(db);
1821    tracing::info!(%project_id, "join remote project");
1822    join_project_internal(response, session, project, replica_id)
1823}
1824
/// Delivers the initial `JoinProjectResponse` for `join_project_internal`,
/// which accepts any `impl JoinProjectInternalResponse` instead of a concrete
/// response type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Adapter for the regular RPC path. The fully-qualified
// `Response::<proto::JoinProject>::send` path names the concrete type being
// delegated to; presumably it resolves to `Response`'s inherent `send`
// (inherent items take precedence on a type path) — NOTE(review): keep this
// form so the call cannot become a recursive call to the trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1833
1834fn join_project_internal(
1835    response: impl JoinProjectInternalResponse,
1836    session: Session,
1837    project: &mut Project,
1838    replica_id: &ReplicaId,
1839) -> Result<()> {
1840    let collaborators = project
1841        .collaborators
1842        .iter()
1843        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1844        .map(|collaborator| collaborator.to_proto())
1845        .collect::<Vec<_>>();
1846    let project_id = project.id;
1847    let guest_user_id = session.user_id();
1848
1849    let worktrees = project
1850        .worktrees
1851        .iter()
1852        .map(|(id, worktree)| proto::WorktreeMetadata {
1853            id: *id,
1854            root_name: worktree.root_name.clone(),
1855            visible: worktree.visible,
1856            abs_path: worktree.abs_path.clone(),
1857        })
1858        .collect::<Vec<_>>();
1859
1860    let add_project_collaborator = proto::AddProjectCollaborator {
1861        project_id: project_id.to_proto(),
1862        collaborator: Some(proto::Collaborator {
1863            peer_id: Some(session.connection_id.into()),
1864            replica_id: replica_id.0 as u32,
1865            user_id: guest_user_id.to_proto(),
1866            is_host: false,
1867        }),
1868    };
1869
1870    for collaborator in &collaborators {
1871        session
1872            .peer
1873            .send(
1874                collaborator.peer_id.unwrap().into(),
1875                add_project_collaborator.clone(),
1876            )
1877            .trace_err();
1878    }
1879
1880    // First, we send the metadata associated with each worktree.
1881    response.send(proto::JoinProjectResponse {
1882        project_id: project.id.0 as u64,
1883        worktrees: worktrees.clone(),
1884        replica_id: replica_id.0 as u32,
1885        collaborators: collaborators.clone(),
1886        language_servers: project.language_servers.clone(),
1887        role: project.role.into(),
1888    })?;
1889
1890    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1891        // Stream this worktree's entries.
1892        let message = proto::UpdateWorktree {
1893            project_id: project_id.to_proto(),
1894            worktree_id,
1895            abs_path: worktree.abs_path.clone(),
1896            root_name: worktree.root_name,
1897            updated_entries: worktree.entries,
1898            removed_entries: Default::default(),
1899            scan_id: worktree.scan_id,
1900            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1901            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1902            removed_repositories: Default::default(),
1903        };
1904        for update in proto::split_worktree_update(message) {
1905            session.peer.send(session.connection_id, update.clone())?;
1906        }
1907
1908        // Stream this worktree's diagnostics.
1909        for summary in worktree.diagnostic_summaries {
1910            session.peer.send(
1911                session.connection_id,
1912                proto::UpdateDiagnosticSummary {
1913                    project_id: project_id.to_proto(),
1914                    worktree_id: worktree.id,
1915                    summary: Some(summary),
1916                },
1917            )?;
1918        }
1919
1920        for settings_file in worktree.settings_files {
1921            session.peer.send(
1922                session.connection_id,
1923                proto::UpdateWorktreeSettings {
1924                    project_id: project_id.to_proto(),
1925                    worktree_id: worktree.id,
1926                    path: settings_file.path,
1927                    content: Some(settings_file.content),
1928                    kind: Some(settings_file.kind.to_proto() as i32),
1929                },
1930            )?;
1931        }
1932    }
1933
1934    for repository in mem::take(&mut project.repositories) {
1935        for update in split_repository_update(repository) {
1936            session.peer.send(session.connection_id, update)?;
1937        }
1938    }
1939
1940    for language_server in &project.language_servers {
1941        session.peer.send(
1942            session.connection_id,
1943            proto::UpdateLanguageServer {
1944                project_id: project_id.to_proto(),
1945                language_server_id: language_server.id,
1946                variant: Some(
1947                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1948                        proto::LspDiskBasedDiagnosticsUpdated {},
1949                    ),
1950                ),
1951            },
1952        )?;
1953    }
1954
1955    Ok(())
1956}
1957
1958/// Leave someone elses shared project.
1959async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1960    let sender_id = session.connection_id;
1961    let project_id = ProjectId::from_proto(request.project_id);
1962    let db = session.db().await;
1963
1964    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1965    tracing::info!(
1966        %project_id,
1967        "leave project"
1968    );
1969
1970    project_left(project, &session);
1971    if let Some(room) = room {
1972        room_updated(room, &session.peer);
1973    }
1974
1975    Ok(())
1976}
1977
1978/// Updates other participants with changes to the project
1979async fn update_project(
1980    request: proto::UpdateProject,
1981    response: Response<proto::UpdateProject>,
1982    session: Session,
1983) -> Result<()> {
1984    let project_id = ProjectId::from_proto(request.project_id);
1985    let (room, guest_connection_ids) = &*session
1986        .db()
1987        .await
1988        .update_project(project_id, session.connection_id, &request.worktrees)
1989        .await?;
1990    broadcast(
1991        Some(session.connection_id),
1992        guest_connection_ids.iter().copied(),
1993        |connection_id| {
1994            session
1995                .peer
1996                .forward_send(session.connection_id, connection_id, request.clone())
1997        },
1998    );
1999    if let Some(room) = room {
2000        room_updated(room, &session.peer);
2001    }
2002    response.send(proto::Ack {})?;
2003
2004    Ok(())
2005}
2006
2007/// Updates other participants with changes to the worktree
2008async fn update_worktree(
2009    request: proto::UpdateWorktree,
2010    response: Response<proto::UpdateWorktree>,
2011    session: Session,
2012) -> Result<()> {
2013    let guest_connection_ids = session
2014        .db()
2015        .await
2016        .update_worktree(&request, session.connection_id)
2017        .await?;
2018
2019    broadcast(
2020        Some(session.connection_id),
2021        guest_connection_ids.iter().copied(),
2022        |connection_id| {
2023            session
2024                .peer
2025                .forward_send(session.connection_id, connection_id, request.clone())
2026        },
2027    );
2028    response.send(proto::Ack {})?;
2029    Ok(())
2030}
2031
2032async fn update_repository(
2033    request: proto::UpdateRepository,
2034    response: Response<proto::UpdateRepository>,
2035    session: Session,
2036) -> Result<()> {
2037    let guest_connection_ids = session
2038        .db()
2039        .await
2040        .update_repository(&request, session.connection_id)
2041        .await?;
2042
2043    broadcast(
2044        Some(session.connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |connection_id| {
2047            session
2048                .peer
2049                .forward_send(session.connection_id, connection_id, request.clone())
2050        },
2051    );
2052    response.send(proto::Ack {})?;
2053    Ok(())
2054}
2055
2056async fn remove_repository(
2057    request: proto::RemoveRepository,
2058    response: Response<proto::RemoveRepository>,
2059    session: Session,
2060) -> Result<()> {
2061    let guest_connection_ids = session
2062        .db()
2063        .await
2064        .remove_repository(&request, session.connection_id)
2065        .await?;
2066
2067    broadcast(
2068        Some(session.connection_id),
2069        guest_connection_ids.iter().copied(),
2070        |connection_id| {
2071            session
2072                .peer
2073                .forward_send(session.connection_id, connection_id, request.clone())
2074        },
2075    );
2076    response.send(proto::Ack {})?;
2077    Ok(())
2078}
2079
2080/// Updates other participants with changes to the diagnostics
2081async fn update_diagnostic_summary(
2082    message: proto::UpdateDiagnosticSummary,
2083    session: Session,
2084) -> Result<()> {
2085    let guest_connection_ids = session
2086        .db()
2087        .await
2088        .update_diagnostic_summary(&message, session.connection_id)
2089        .await?;
2090
2091    broadcast(
2092        Some(session.connection_id),
2093        guest_connection_ids.iter().copied(),
2094        |connection_id| {
2095            session
2096                .peer
2097                .forward_send(session.connection_id, connection_id, message.clone())
2098        },
2099    );
2100
2101    Ok(())
2102}
2103
2104/// Updates other participants with changes to the worktree settings
2105async fn update_worktree_settings(
2106    message: proto::UpdateWorktreeSettings,
2107    session: Session,
2108) -> Result<()> {
2109    let guest_connection_ids = session
2110        .db()
2111        .await
2112        .update_worktree_settings(&message, session.connection_id)
2113        .await?;
2114
2115    broadcast(
2116        Some(session.connection_id),
2117        guest_connection_ids.iter().copied(),
2118        |connection_id| {
2119            session
2120                .peer
2121                .forward_send(session.connection_id, connection_id, message.clone())
2122        },
2123    );
2124
2125    Ok(())
2126}
2127
2128/// Notify other participants that a language server has started.
2129async fn start_language_server(
2130    request: proto::StartLanguageServer,
2131    session: Session,
2132) -> Result<()> {
2133    let guest_connection_ids = session
2134        .db()
2135        .await
2136        .start_language_server(&request, session.connection_id)
2137        .await?;
2138
2139    broadcast(
2140        Some(session.connection_id),
2141        guest_connection_ids.iter().copied(),
2142        |connection_id| {
2143            session
2144                .peer
2145                .forward_send(session.connection_id, connection_id, request.clone())
2146        },
2147    );
2148    Ok(())
2149}
2150
2151/// Notify other participants that a language server has changed.
2152async fn update_language_server(
2153    request: proto::UpdateLanguageServer,
2154    session: Session,
2155) -> Result<()> {
2156    let project_id = ProjectId::from_proto(request.project_id);
2157    let project_connection_ids = session
2158        .db()
2159        .await
2160        .project_connection_ids(project_id, session.connection_id, true)
2161        .await?;
2162    broadcast(
2163        Some(session.connection_id),
2164        project_connection_ids.iter().copied(),
2165        |connection_id| {
2166            session
2167                .peer
2168                .forward_send(session.connection_id, connection_id, request.clone())
2169        },
2170    );
2171    Ok(())
2172}
2173
2174/// forward a project request to the host. These requests should be read only
2175/// as guests are allowed to send them.
2176async fn forward_read_only_project_request<T>(
2177    request: T,
2178    response: Response<T>,
2179    session: Session,
2180) -> Result<()>
2181where
2182    T: EntityMessage + RequestMessage,
2183{
2184    let project_id = ProjectId::from_proto(request.remote_entity_id());
2185    let host_connection_id = session
2186        .db()
2187        .await
2188        .host_for_read_only_project_request(project_id, session.connection_id)
2189        .await?;
2190    let payload = session
2191        .peer
2192        .forward_request(session.connection_id, host_connection_id, request)
2193        .await?;
2194    response.send(payload)?;
2195    Ok(())
2196}
2197
2198async fn forward_find_search_candidates_request(
2199    request: proto::FindSearchCandidates,
2200    response: Response<proto::FindSearchCandidates>,
2201    session: Session,
2202) -> Result<()> {
2203    let project_id = ProjectId::from_proto(request.remote_entity_id());
2204    let host_connection_id = session
2205        .db()
2206        .await
2207        .host_for_read_only_project_request(project_id, session.connection_id)
2208        .await?;
2209    let payload = session
2210        .peer
2211        .forward_request(session.connection_id, host_connection_id, request)
2212        .await?;
2213    response.send(payload)?;
2214    Ok(())
2215}
2216
2217/// forward a project request to the host. These requests are disallowed
2218/// for guests.
2219async fn forward_mutating_project_request<T>(
2220    request: T,
2221    response: Response<T>,
2222    session: Session,
2223) -> Result<()>
2224where
2225    T: EntityMessage + RequestMessage,
2226{
2227    let project_id = ProjectId::from_proto(request.remote_entity_id());
2228
2229    let host_connection_id = session
2230        .db()
2231        .await
2232        .host_for_mutating_project_request(project_id, session.connection_id)
2233        .await?;
2234    let payload = session
2235        .peer
2236        .forward_request(session.connection_id, host_connection_id, request)
2237        .await?;
2238    response.send(payload)?;
2239    Ok(())
2240}
2241
2242/// Notify other participants that a new buffer has been created
2243async fn create_buffer_for_peer(
2244    request: proto::CreateBufferForPeer,
2245    session: Session,
2246) -> Result<()> {
2247    session
2248        .db()
2249        .await
2250        .check_user_is_project_host(
2251            ProjectId::from_proto(request.project_id),
2252            session.connection_id,
2253        )
2254        .await?;
2255    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2256    session
2257        .peer
2258        .forward_send(session.connection_id, peer_id.into(), request)?;
2259    Ok(())
2260}
2261
2262/// Notify other participants that a buffer has been updated. This is
2263/// allowed for guests as long as the update is limited to selections.
2264async fn update_buffer(
2265    request: proto::UpdateBuffer,
2266    response: Response<proto::UpdateBuffer>,
2267    session: Session,
2268) -> Result<()> {
2269    let project_id = ProjectId::from_proto(request.project_id);
2270    let mut capability = Capability::ReadOnly;
2271
2272    for op in request.operations.iter() {
2273        match op.variant {
2274            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2275            Some(_) => capability = Capability::ReadWrite,
2276        }
2277    }
2278
2279    let host = {
2280        let guard = session
2281            .db()
2282            .await
2283            .connections_for_buffer_update(project_id, session.connection_id, capability)
2284            .await?;
2285
2286        let (host, guests) = &*guard;
2287
2288        broadcast(
2289            Some(session.connection_id),
2290            guests.clone(),
2291            |connection_id| {
2292                session
2293                    .peer
2294                    .forward_send(session.connection_id, connection_id, request.clone())
2295            },
2296        );
2297
2298        *host
2299    };
2300
2301    if host != session.connection_id {
2302        session
2303            .peer
2304            .forward_request(session.connection_id, host, request.clone())
2305            .await?;
2306    }
2307
2308    response.send(proto::Ack {})?;
2309    Ok(())
2310}
2311
2312async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2313    let project_id = ProjectId::from_proto(message.project_id);
2314
2315    let operation = message.operation.as_ref().context("invalid operation")?;
2316    let capability = match operation.variant.as_ref() {
2317        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2318            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2319                match buffer_op.variant {
2320                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2321                        Capability::ReadOnly
2322                    }
2323                    _ => Capability::ReadWrite,
2324                }
2325            } else {
2326                Capability::ReadWrite
2327            }
2328        }
2329        Some(_) => Capability::ReadWrite,
2330        None => Capability::ReadOnly,
2331    };
2332
2333    let guard = session
2334        .db()
2335        .await
2336        .connections_for_buffer_update(project_id, session.connection_id, capability)
2337        .await?;
2338
2339    let (host, guests) = &*guard;
2340
2341    broadcast(
2342        Some(session.connection_id),
2343        guests.iter().chain([host]).copied(),
2344        |connection_id| {
2345            session
2346                .peer
2347                .forward_send(session.connection_id, connection_id, message.clone())
2348        },
2349    );
2350
2351    Ok(())
2352}
2353
2354/// Notify other participants that a project has been updated.
2355async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2356    request: T,
2357    session: Session,
2358) -> Result<()> {
2359    let project_id = ProjectId::from_proto(request.remote_entity_id());
2360    let project_connection_ids = session
2361        .db()
2362        .await
2363        .project_connection_ids(project_id, session.connection_id, false)
2364        .await?;
2365
2366    broadcast(
2367        Some(session.connection_id),
2368        project_connection_ids.iter().copied(),
2369        |connection_id| {
2370            session
2371                .peer
2372                .forward_send(session.connection_id, connection_id, request.clone())
2373        },
2374    );
2375    Ok(())
2376}
2377
2378/// Start following another user in a call.
2379async fn follow(
2380    request: proto::Follow,
2381    response: Response<proto::Follow>,
2382    session: Session,
2383) -> Result<()> {
2384    let room_id = RoomId::from_proto(request.room_id);
2385    let project_id = request.project_id.map(ProjectId::from_proto);
2386    let leader_id = request
2387        .leader_id
2388        .ok_or_else(|| anyhow!("invalid leader id"))?
2389        .into();
2390    let follower_id = session.connection_id;
2391
2392    session
2393        .db()
2394        .await
2395        .check_room_participants(room_id, leader_id, session.connection_id)
2396        .await?;
2397
2398    let response_payload = session
2399        .peer
2400        .forward_request(session.connection_id, leader_id, request)
2401        .await?;
2402    response.send(response_payload)?;
2403
2404    if let Some(project_id) = project_id {
2405        let room = session
2406            .db()
2407            .await
2408            .follow(room_id, project_id, leader_id, follower_id)
2409            .await?;
2410        room_updated(&room, &session.peer);
2411    }
2412
2413    Ok(())
2414}
2415
2416/// Stop following another user in a call.
2417async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2418    let room_id = RoomId::from_proto(request.room_id);
2419    let project_id = request.project_id.map(ProjectId::from_proto);
2420    let leader_id = request
2421        .leader_id
2422        .ok_or_else(|| anyhow!("invalid leader id"))?
2423        .into();
2424    let follower_id = session.connection_id;
2425
2426    session
2427        .db()
2428        .await
2429        .check_room_participants(room_id, leader_id, session.connection_id)
2430        .await?;
2431
2432    session
2433        .peer
2434        .forward_send(session.connection_id, leader_id, request)?;
2435
2436    if let Some(project_id) = project_id {
2437        let room = session
2438            .db()
2439            .await
2440            .unfollow(room_id, project_id, leader_id, follower_id)
2441            .await?;
2442        room_updated(&room, &session.peer);
2443    }
2444
2445    Ok(())
2446}
2447
2448/// Notify everyone following you of your current location.
2449async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2450    let room_id = RoomId::from_proto(request.room_id);
2451    let database = session.db.lock().await;
2452
2453    let connection_ids = if let Some(project_id) = request.project_id {
2454        let project_id = ProjectId::from_proto(project_id);
2455        database
2456            .project_connection_ids(project_id, session.connection_id, true)
2457            .await?
2458    } else {
2459        database
2460            .room_connection_ids(room_id, session.connection_id)
2461            .await?
2462    };
2463
2464    // For now, don't send view update messages back to that view's current leader.
2465    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2466        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2467        _ => None,
2468    });
2469
2470    for connection_id in connection_ids.iter().cloned() {
2471        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2472            session
2473                .peer
2474                .forward_send(session.connection_id, connection_id, request.clone())?;
2475        }
2476    }
2477    Ok(())
2478}
2479
2480/// Get public data about users.
2481async fn get_users(
2482    request: proto::GetUsers,
2483    response: Response<proto::GetUsers>,
2484    session: Session,
2485) -> Result<()> {
2486    let user_ids = request
2487        .user_ids
2488        .into_iter()
2489        .map(UserId::from_proto)
2490        .collect();
2491    let users = session
2492        .db()
2493        .await
2494        .get_users_by_ids(user_ids)
2495        .await?
2496        .into_iter()
2497        .map(|user| proto::User {
2498            id: user.id.to_proto(),
2499            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2500            github_login: user.github_login,
2501            email: user.email_address,
2502            name: user.name,
2503        })
2504        .collect();
2505    response.send(proto::UsersResponse { users })?;
2506    Ok(())
2507}
2508
2509/// Search for users (to invite) buy Github login
2510async fn fuzzy_search_users(
2511    request: proto::FuzzySearchUsers,
2512    response: Response<proto::FuzzySearchUsers>,
2513    session: Session,
2514) -> Result<()> {
2515    let query = request.query;
2516    let users = match query.len() {
2517        0 => vec![],
2518        1 | 2 => session
2519            .db()
2520            .await
2521            .get_user_by_github_login(&query)
2522            .await?
2523            .into_iter()
2524            .collect(),
2525        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2526    };
2527    let users = users
2528        .into_iter()
2529        .filter(|user| user.id != session.user_id())
2530        .map(|user| proto::User {
2531            id: user.id.to_proto(),
2532            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2533            github_login: user.github_login,
2534            name: user.name,
2535            email: user.email_address,
2536        })
2537        .collect();
2538    response.send(proto::UsersResponse { users })?;
2539    Ok(())
2540}
2541
2542/// Send a contact request to another user.
2543async fn request_contact(
2544    request: proto::RequestContact,
2545    response: Response<proto::RequestContact>,
2546    session: Session,
2547) -> Result<()> {
2548    let requester_id = session.user_id();
2549    let responder_id = UserId::from_proto(request.responder_id);
2550    if requester_id == responder_id {
2551        return Err(anyhow!("cannot add yourself as a contact"))?;
2552    }
2553
2554    let notifications = session
2555        .db()
2556        .await
2557        .send_contact_request(requester_id, responder_id)
2558        .await?;
2559
2560    // Update outgoing contact requests of requester
2561    let mut update = proto::UpdateContacts::default();
2562    update.outgoing_requests.push(responder_id.to_proto());
2563    for connection_id in session
2564        .connection_pool()
2565        .await
2566        .user_connection_ids(requester_id)
2567    {
2568        session.peer.send(connection_id, update.clone())?;
2569    }
2570
2571    // Update incoming contact requests of responder
2572    let mut update = proto::UpdateContacts::default();
2573    update
2574        .incoming_requests
2575        .push(proto::IncomingContactRequest {
2576            requester_id: requester_id.to_proto(),
2577        });
2578    let connection_pool = session.connection_pool().await;
2579    for connection_id in connection_pool.user_connection_ids(responder_id) {
2580        session.peer.send(connection_id, update.clone())?;
2581    }
2582
2583    send_notifications(&connection_pool, &session.peer, notifications);
2584
2585    response.send(proto::Ack {})?;
2586    Ok(())
2587}
2588
2589/// Accept or decline a contact request
2590async fn respond_to_contact_request(
2591    request: proto::RespondToContactRequest,
2592    response: Response<proto::RespondToContactRequest>,
2593    session: Session,
2594) -> Result<()> {
2595    let responder_id = session.user_id();
2596    let requester_id = UserId::from_proto(request.requester_id);
2597    let db = session.db().await;
2598    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2599        db.dismiss_contact_notification(responder_id, requester_id)
2600            .await?;
2601    } else {
2602        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2603
2604        let notifications = db
2605            .respond_to_contact_request(responder_id, requester_id, accept)
2606            .await?;
2607        let requester_busy = db.is_user_busy(requester_id).await?;
2608        let responder_busy = db.is_user_busy(responder_id).await?;
2609
2610        let pool = session.connection_pool().await;
2611        // Update responder with new contact
2612        let mut update = proto::UpdateContacts::default();
2613        if accept {
2614            update
2615                .contacts
2616                .push(contact_for_user(requester_id, requester_busy, &pool));
2617        }
2618        update
2619            .remove_incoming_requests
2620            .push(requester_id.to_proto());
2621        for connection_id in pool.user_connection_ids(responder_id) {
2622            session.peer.send(connection_id, update.clone())?;
2623        }
2624
2625        // Update requester with new contact
2626        let mut update = proto::UpdateContacts::default();
2627        if accept {
2628            update
2629                .contacts
2630                .push(contact_for_user(responder_id, responder_busy, &pool));
2631        }
2632        update
2633            .remove_outgoing_requests
2634            .push(responder_id.to_proto());
2635
2636        for connection_id in pool.user_connection_ids(requester_id) {
2637            session.peer.send(connection_id, update.clone())?;
2638        }
2639
2640        send_notifications(&pool, &session.peer, notifications);
2641    }
2642
2643    response.send(proto::Ack {})?;
2644    Ok(())
2645}
2646
2647/// Remove a contact.
2648async fn remove_contact(
2649    request: proto::RemoveContact,
2650    response: Response<proto::RemoveContact>,
2651    session: Session,
2652) -> Result<()> {
2653    let requester_id = session.user_id();
2654    let responder_id = UserId::from_proto(request.user_id);
2655    let db = session.db().await;
2656    let (contact_accepted, deleted_notification_id) =
2657        db.remove_contact(requester_id, responder_id).await?;
2658
2659    let pool = session.connection_pool().await;
2660    // Update outgoing contact requests of requester
2661    let mut update = proto::UpdateContacts::default();
2662    if contact_accepted {
2663        update.remove_contacts.push(responder_id.to_proto());
2664    } else {
2665        update
2666            .remove_outgoing_requests
2667            .push(responder_id.to_proto());
2668    }
2669    for connection_id in pool.user_connection_ids(requester_id) {
2670        session.peer.send(connection_id, update.clone())?;
2671    }
2672
2673    // Update incoming contact requests of responder
2674    let mut update = proto::UpdateContacts::default();
2675    if contact_accepted {
2676        update.remove_contacts.push(requester_id.to_proto());
2677    } else {
2678        update
2679            .remove_incoming_requests
2680            .push(requester_id.to_proto());
2681    }
2682    for connection_id in pool.user_connection_ids(responder_id) {
2683        session.peer.send(connection_id, update.clone())?;
2684        if let Some(notification_id) = deleted_notification_id {
2685            session.peer.send(
2686                connection_id,
2687                proto::DeleteNotification {
2688                    notification_id: notification_id.to_proto(),
2689                },
2690            )?;
2691        }
2692    }
2693
2694    response.send(proto::Ack {})?;
2695    Ok(())
2696}
2697
2698fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2699    version.0.minor() < 139
2700}
2701
2702async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2703    let plan = session.current_plan(&session.db().await).await?;
2704
2705    session
2706        .peer
2707        .send(
2708            session.connection_id,
2709            proto::UpdateUserPlan { plan: plan.into() },
2710        )
2711        .trace_err();
2712
2713    Ok(())
2714}
2715
2716async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2717    subscribe_user_to_channels(session.user_id(), &session).await?;
2718    Ok(())
2719}
2720
2721async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2722    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2723    let mut pool = session.connection_pool().await;
2724    for membership in &channels_for_user.channel_memberships {
2725        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2726    }
2727    session.peer.send(
2728        session.connection_id,
2729        build_update_user_channels(&channels_for_user),
2730    )?;
2731    session.peer.send(
2732        session.connection_id,
2733        build_channels_update(channels_for_user),
2734    )?;
2735    Ok(())
2736}
2737
2738/// Creates a new channel.
2739async fn create_channel(
2740    request: proto::CreateChannel,
2741    response: Response<proto::CreateChannel>,
2742    session: Session,
2743) -> Result<()> {
2744    let db = session.db().await;
2745
2746    let parent_id = request.parent_id.map(ChannelId::from_proto);
2747    let (channel, membership) = db
2748        .create_channel(&request.name, parent_id, session.user_id())
2749        .await?;
2750
2751    let root_id = channel.root_id();
2752    let channel = Channel::from_model(channel);
2753
2754    response.send(proto::CreateChannelResponse {
2755        channel: Some(channel.to_proto()),
2756        parent_id: request.parent_id,
2757    })?;
2758
2759    let mut connection_pool = session.connection_pool().await;
2760    if let Some(membership) = membership {
2761        connection_pool.subscribe_to_channel(
2762            membership.user_id,
2763            membership.channel_id,
2764            membership.role,
2765        );
2766        let update = proto::UpdateUserChannels {
2767            channel_memberships: vec![proto::ChannelMembership {
2768                channel_id: membership.channel_id.to_proto(),
2769                role: membership.role.into(),
2770            }],
2771            ..Default::default()
2772        };
2773        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2774            session.peer.send(connection_id, update.clone())?;
2775        }
2776    }
2777
2778    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2779        if !role.can_see_channel(channel.visibility) {
2780            continue;
2781        }
2782
2783        let update = proto::UpdateChannels {
2784            channels: vec![channel.to_proto()],
2785            ..Default::default()
2786        };
2787        session.peer.send(connection_id, update.clone())?;
2788    }
2789
2790    Ok(())
2791}
2792
2793/// Delete a channel
2794async fn delete_channel(
2795    request: proto::DeleteChannel,
2796    response: Response<proto::DeleteChannel>,
2797    session: Session,
2798) -> Result<()> {
2799    let db = session.db().await;
2800
2801    let channel_id = request.channel_id;
2802    let (root_channel, removed_channels) = db
2803        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2804        .await?;
2805    response.send(proto::Ack {})?;
2806
2807    // Notify members of removed channels
2808    let mut update = proto::UpdateChannels::default();
2809    update
2810        .delete_channels
2811        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2812
2813    let connection_pool = session.connection_pool().await;
2814    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2815        session.peer.send(connection_id, update.clone())?;
2816    }
2817
2818    Ok(())
2819}
2820
2821/// Invite someone to join a channel.
2822async fn invite_channel_member(
2823    request: proto::InviteChannelMember,
2824    response: Response<proto::InviteChannelMember>,
2825    session: Session,
2826) -> Result<()> {
2827    let db = session.db().await;
2828    let channel_id = ChannelId::from_proto(request.channel_id);
2829    let invitee_id = UserId::from_proto(request.user_id);
2830    let InviteMemberResult {
2831        channel,
2832        notifications,
2833    } = db
2834        .invite_channel_member(
2835            channel_id,
2836            invitee_id,
2837            session.user_id(),
2838            request.role().into(),
2839        )
2840        .await?;
2841
2842    let update = proto::UpdateChannels {
2843        channel_invitations: vec![channel.to_proto()],
2844        ..Default::default()
2845    };
2846
2847    let connection_pool = session.connection_pool().await;
2848    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2849        session.peer.send(connection_id, update.clone())?;
2850    }
2851
2852    send_notifications(&connection_pool, &session.peer, notifications);
2853
2854    response.send(proto::Ack {})?;
2855    Ok(())
2856}
2857
2858/// remove someone from a channel
2859async fn remove_channel_member(
2860    request: proto::RemoveChannelMember,
2861    response: Response<proto::RemoveChannelMember>,
2862    session: Session,
2863) -> Result<()> {
2864    let db = session.db().await;
2865    let channel_id = ChannelId::from_proto(request.channel_id);
2866    let member_id = UserId::from_proto(request.user_id);
2867
2868    let RemoveChannelMemberResult {
2869        membership_update,
2870        notification_id,
2871    } = db
2872        .remove_channel_member(channel_id, member_id, session.user_id())
2873        .await?;
2874
2875    let mut connection_pool = session.connection_pool().await;
2876    notify_membership_updated(
2877        &mut connection_pool,
2878        membership_update,
2879        member_id,
2880        &session.peer,
2881    );
2882    for connection_id in connection_pool.user_connection_ids(member_id) {
2883        if let Some(notification_id) = notification_id {
2884            session
2885                .peer
2886                .send(
2887                    connection_id,
2888                    proto::DeleteNotification {
2889                        notification_id: notification_id.to_proto(),
2890                    },
2891                )
2892                .trace_err();
2893        }
2894    }
2895
2896    response.send(proto::Ack {})?;
2897    Ok(())
2898}
2899
2900/// Toggle the channel between public and private.
2901/// Care is taken to maintain the invariant that public channels only descend from public channels,
2902/// (though members-only channels can appear at any point in the hierarchy).
2903async fn set_channel_visibility(
2904    request: proto::SetChannelVisibility,
2905    response: Response<proto::SetChannelVisibility>,
2906    session: Session,
2907) -> Result<()> {
2908    let db = session.db().await;
2909    let channel_id = ChannelId::from_proto(request.channel_id);
2910    let visibility = request.visibility().into();
2911
2912    let channel_model = db
2913        .set_channel_visibility(channel_id, visibility, session.user_id())
2914        .await?;
2915    let root_id = channel_model.root_id();
2916    let channel = Channel::from_model(channel_model);
2917
2918    let mut connection_pool = session.connection_pool().await;
2919    for (user_id, role) in connection_pool
2920        .channel_user_ids(root_id)
2921        .collect::<Vec<_>>()
2922        .into_iter()
2923    {
2924        let update = if role.can_see_channel(channel.visibility) {
2925            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2926            proto::UpdateChannels {
2927                channels: vec![channel.to_proto()],
2928                ..Default::default()
2929            }
2930        } else {
2931            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2932            proto::UpdateChannels {
2933                delete_channels: vec![channel.id.to_proto()],
2934                ..Default::default()
2935            }
2936        };
2937
2938        for connection_id in connection_pool.user_connection_ids(user_id) {
2939            session.peer.send(connection_id, update.clone())?;
2940        }
2941    }
2942
2943    response.send(proto::Ack {})?;
2944    Ok(())
2945}
2946
2947/// Alter the role for a user in the channel.
2948async fn set_channel_member_role(
2949    request: proto::SetChannelMemberRole,
2950    response: Response<proto::SetChannelMemberRole>,
2951    session: Session,
2952) -> Result<()> {
2953    let db = session.db().await;
2954    let channel_id = ChannelId::from_proto(request.channel_id);
2955    let member_id = UserId::from_proto(request.user_id);
2956    let result = db
2957        .set_channel_member_role(
2958            channel_id,
2959            session.user_id(),
2960            member_id,
2961            request.role().into(),
2962        )
2963        .await?;
2964
2965    match result {
2966        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2967            let mut connection_pool = session.connection_pool().await;
2968            notify_membership_updated(
2969                &mut connection_pool,
2970                membership_update,
2971                member_id,
2972                &session.peer,
2973            )
2974        }
2975        db::SetMemberRoleResult::InviteUpdated(channel) => {
2976            let update = proto::UpdateChannels {
2977                channel_invitations: vec![channel.to_proto()],
2978                ..Default::default()
2979            };
2980
2981            for connection_id in session
2982                .connection_pool()
2983                .await
2984                .user_connection_ids(member_id)
2985            {
2986                session.peer.send(connection_id, update.clone())?;
2987            }
2988        }
2989    }
2990
2991    response.send(proto::Ack {})?;
2992    Ok(())
2993}
2994
2995/// Change the name of a channel
2996async fn rename_channel(
2997    request: proto::RenameChannel,
2998    response: Response<proto::RenameChannel>,
2999    session: Session,
3000) -> Result<()> {
3001    let db = session.db().await;
3002    let channel_id = ChannelId::from_proto(request.channel_id);
3003    let channel_model = db
3004        .rename_channel(channel_id, session.user_id(), &request.name)
3005        .await?;
3006    let root_id = channel_model.root_id();
3007    let channel = Channel::from_model(channel_model);
3008
3009    response.send(proto::RenameChannelResponse {
3010        channel: Some(channel.to_proto()),
3011    })?;
3012
3013    let connection_pool = session.connection_pool().await;
3014    let update = proto::UpdateChannels {
3015        channels: vec![channel.to_proto()],
3016        ..Default::default()
3017    };
3018    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3019        if role.can_see_channel(channel.visibility) {
3020            session.peer.send(connection_id, update.clone())?;
3021        }
3022    }
3023
3024    Ok(())
3025}
3026
3027/// Move a channel to a new parent.
3028async fn move_channel(
3029    request: proto::MoveChannel,
3030    response: Response<proto::MoveChannel>,
3031    session: Session,
3032) -> Result<()> {
3033    let channel_id = ChannelId::from_proto(request.channel_id);
3034    let to = ChannelId::from_proto(request.to);
3035
3036    let (root_id, channels) = session
3037        .db()
3038        .await
3039        .move_channel(channel_id, to, session.user_id())
3040        .await?;
3041
3042    let connection_pool = session.connection_pool().await;
3043    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3044        let channels = channels
3045            .iter()
3046            .filter_map(|channel| {
3047                if role.can_see_channel(channel.visibility) {
3048                    Some(channel.to_proto())
3049                } else {
3050                    None
3051                }
3052            })
3053            .collect::<Vec<_>>();
3054        if channels.is_empty() {
3055            continue;
3056        }
3057
3058        let update = proto::UpdateChannels {
3059            channels,
3060            ..Default::default()
3061        };
3062
3063        session.peer.send(connection_id, update.clone())?;
3064    }
3065
3066    response.send(Ack {})?;
3067    Ok(())
3068}
3069
3070/// Get the list of channel members
3071async fn get_channel_members(
3072    request: proto::GetChannelMembers,
3073    response: Response<proto::GetChannelMembers>,
3074    session: Session,
3075) -> Result<()> {
3076    let db = session.db().await;
3077    let channel_id = ChannelId::from_proto(request.channel_id);
3078    let limit = if request.limit == 0 {
3079        u16::MAX as u64
3080    } else {
3081        request.limit
3082    };
3083    let (members, users) = db
3084        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3085        .await?;
3086    response.send(proto::GetChannelMembersResponse { members, users })?;
3087    Ok(())
3088}
3089
3090/// Accept or decline a channel invitation.
3091async fn respond_to_channel_invite(
3092    request: proto::RespondToChannelInvite,
3093    response: Response<proto::RespondToChannelInvite>,
3094    session: Session,
3095) -> Result<()> {
3096    let db = session.db().await;
3097    let channel_id = ChannelId::from_proto(request.channel_id);
3098    let RespondToChannelInvite {
3099        membership_update,
3100        notifications,
3101    } = db
3102        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3103        .await?;
3104
3105    let mut connection_pool = session.connection_pool().await;
3106    if let Some(membership_update) = membership_update {
3107        notify_membership_updated(
3108            &mut connection_pool,
3109            membership_update,
3110            session.user_id(),
3111            &session.peer,
3112        );
3113    } else {
3114        let update = proto::UpdateChannels {
3115            remove_channel_invitations: vec![channel_id.to_proto()],
3116            ..Default::default()
3117        };
3118
3119        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3120            session.peer.send(connection_id, update.clone())?;
3121        }
3122    };
3123
3124    send_notifications(&connection_pool, &session.peer, notifications);
3125
3126    response.send(proto::Ack {})?;
3127
3128    Ok(())
3129}
3130
3131/// Join the channels' room
3132async fn join_channel(
3133    request: proto::JoinChannel,
3134    response: Response<proto::JoinChannel>,
3135    session: Session,
3136) -> Result<()> {
3137    let channel_id = ChannelId::from_proto(request.channel_id);
3138    join_channel_internal(channel_id, Box::new(response), session).await
3139}
3140
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so that `join_channel_internal` can serve
/// either request type.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` request be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path resolves to `Response`'s own `send` (inherent
        // items take precedence in path lookup), not back to this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` request be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path resolves to `Response`'s own `send` (inherent
        // items take precedence in path lookup), not back to this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3154
/// Shared implementation for joining a channel's room, used by `join_channel`
/// and any other handler whose response implements `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is made to leave its room first. The database guard is
///    dropped while `leave_room_for_session` runs and re-acquired afterwards.
/// 2. This session's connection joins the channel's room.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a guest
///    token with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. A token error is traced, and the reply then carries
///    no LiveKit connection info.
/// 4. While still holding the database guard: the reply is sent, any
///    membership change is propagated, and `room_updated` is called.
/// 5. After the guard is released: `channel_updated` is called with the joined
///    channel and room, and the user's contacts are updated.
///
/// # Errors
///
/// Propagates database and send errors. Also fails with "channel not returned"
/// if the joined room has no channel; note this happens after the reply has
/// already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room; it is
            // re-acquired right after. NOTE(review): presumably because
            // `leave_room_for_session` locks the database itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (`trace_err()?` short-circuits the closure to `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish; all other roles can publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database and connection-pool guards.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3250
3251/// Start editing the channel notes
3252async fn join_channel_buffer(
3253    request: proto::JoinChannelBuffer,
3254    response: Response<proto::JoinChannelBuffer>,
3255    session: Session,
3256) -> Result<()> {
3257    let db = session.db().await;
3258    let channel_id = ChannelId::from_proto(request.channel_id);
3259
3260    let open_response = db
3261        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3262        .await?;
3263
3264    let collaborators = open_response.collaborators.clone();
3265    response.send(open_response)?;
3266
3267    let update = UpdateChannelBufferCollaborators {
3268        channel_id: channel_id.to_proto(),
3269        collaborators: collaborators.clone(),
3270    };
3271    channel_buffer_updated(
3272        session.connection_id,
3273        collaborators
3274            .iter()
3275            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3276        &update,
3277        &session.peer,
3278    );
3279
3280    Ok(())
3281}
3282
3283/// Edit the channel notes
3284async fn update_channel_buffer(
3285    request: proto::UpdateChannelBuffer,
3286    session: Session,
3287) -> Result<()> {
3288    let db = session.db().await;
3289    let channel_id = ChannelId::from_proto(request.channel_id);
3290
3291    let (collaborators, epoch, version) = db
3292        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3293        .await?;
3294
3295    channel_buffer_updated(
3296        session.connection_id,
3297        collaborators.clone(),
3298        &proto::UpdateChannelBuffer {
3299            channel_id: channel_id.to_proto(),
3300            operations: request.operations,
3301        },
3302        &session.peer,
3303    );
3304
3305    let pool = &*session.connection_pool().await;
3306
3307    let non_collaborators =
3308        pool.channel_connection_ids(channel_id)
3309            .filter_map(|(connection_id, _)| {
3310                if collaborators.contains(&connection_id) {
3311                    None
3312                } else {
3313                    Some(connection_id)
3314                }
3315            });
3316
3317    broadcast(None, non_collaborators, |peer_id| {
3318        session.peer.send(
3319            peer_id,
3320            proto::UpdateChannels {
3321                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3322                    channel_id: channel_id.to_proto(),
3323                    epoch: epoch as u64,
3324                    version: version.clone(),
3325                }],
3326                ..Default::default()
3327            },
3328        )
3329    });
3330
3331    Ok(())
3332}
3333
3334/// Rejoin the channel notes after a connection blip
3335async fn rejoin_channel_buffers(
3336    request: proto::RejoinChannelBuffers,
3337    response: Response<proto::RejoinChannelBuffers>,
3338    session: Session,
3339) -> Result<()> {
3340    let db = session.db().await;
3341    let buffers = db
3342        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3343        .await?;
3344
3345    for rejoined_buffer in &buffers {
3346        let collaborators_to_notify = rejoined_buffer
3347            .buffer
3348            .collaborators
3349            .iter()
3350            .filter_map(|c| Some(c.peer_id?.into()));
3351        channel_buffer_updated(
3352            session.connection_id,
3353            collaborators_to_notify,
3354            &proto::UpdateChannelBufferCollaborators {
3355                channel_id: rejoined_buffer.buffer.channel_id,
3356                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3357            },
3358            &session.peer,
3359        );
3360    }
3361
3362    response.send(proto::RejoinChannelBuffersResponse {
3363        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3364    })?;
3365
3366    Ok(())
3367}
3368
3369/// Stop editing the channel notes
3370async fn leave_channel_buffer(
3371    request: proto::LeaveChannelBuffer,
3372    response: Response<proto::LeaveChannelBuffer>,
3373    session: Session,
3374) -> Result<()> {
3375    let db = session.db().await;
3376    let channel_id = ChannelId::from_proto(request.channel_id);
3377
3378    let left_buffer = db
3379        .leave_channel_buffer(channel_id, session.connection_id)
3380        .await?;
3381
3382    response.send(Ack {})?;
3383
3384    channel_buffer_updated(
3385        session.connection_id,
3386        left_buffer.connections,
3387        &proto::UpdateChannelBufferCollaborators {
3388            channel_id: channel_id.to_proto(),
3389            collaborators: left_buffer.collaborators,
3390        },
3391        &session.peer,
3392    );
3393
3394    Ok(())
3395}
3396
3397fn channel_buffer_updated<T: EnvelopedMessage>(
3398    sender_id: ConnectionId,
3399    collaborators: impl IntoIterator<Item = ConnectionId>,
3400    message: &T,
3401    peer: &Peer,
3402) {
3403    broadcast(Some(sender_id), collaborators, |peer_id| {
3404        peer.send(peer_id, message.clone())
3405    });
3406}
3407
3408fn send_notifications(
3409    connection_pool: &ConnectionPool,
3410    peer: &Peer,
3411    notifications: db::NotificationBatch,
3412) {
3413    for (user_id, notification) in notifications {
3414        for connection_id in connection_pool.user_connection_ids(user_id) {
3415            if let Err(error) = peer.send(
3416                connection_id,
3417                proto::AddNotification {
3418                    notification: Some(notification.clone()),
3419                },
3420            ) {
3421                tracing::error!(
3422                    "failed to send notification to {:?} {}",
3423                    connection_id,
3424                    error
3425                );
3426            }
3427        }
3428    }
3429}
3430
3431/// Send a message to the channel
3432async fn send_channel_message(
3433    request: proto::SendChannelMessage,
3434    response: Response<proto::SendChannelMessage>,
3435    session: Session,
3436) -> Result<()> {
3437    // Validate the message body.
3438    let body = request.body.trim().to_string();
3439    if body.len() > MAX_MESSAGE_LEN {
3440        return Err(anyhow!("message is too long"))?;
3441    }
3442    if body.is_empty() {
3443        return Err(anyhow!("message can't be blank"))?;
3444    }
3445
3446    // TODO: adjust mentions if body is trimmed
3447
3448    let timestamp = OffsetDateTime::now_utc();
3449    let nonce = request
3450        .nonce
3451        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3452
3453    let channel_id = ChannelId::from_proto(request.channel_id);
3454    let CreatedChannelMessage {
3455        message_id,
3456        participant_connection_ids,
3457        notifications,
3458    } = session
3459        .db()
3460        .await
3461        .create_channel_message(
3462            channel_id,
3463            session.user_id(),
3464            &body,
3465            &request.mentions,
3466            timestamp,
3467            nonce.clone().into(),
3468            request.reply_to_message_id.map(MessageId::from_proto),
3469        )
3470        .await?;
3471
3472    let message = proto::ChannelMessage {
3473        sender_id: session.user_id().to_proto(),
3474        id: message_id.to_proto(),
3475        body,
3476        mentions: request.mentions,
3477        timestamp: timestamp.unix_timestamp() as u64,
3478        nonce: Some(nonce),
3479        reply_to_message_id: request.reply_to_message_id,
3480        edited_at: None,
3481    };
3482    broadcast(
3483        Some(session.connection_id),
3484        participant_connection_ids.clone(),
3485        |connection| {
3486            session.peer.send(
3487                connection,
3488                proto::ChannelMessageSent {
3489                    channel_id: channel_id.to_proto(),
3490                    message: Some(message.clone()),
3491                },
3492            )
3493        },
3494    );
3495    response.send(proto::SendChannelMessageResponse {
3496        message: Some(message),
3497    })?;
3498
3499    let pool = &*session.connection_pool().await;
3500    let non_participants =
3501        pool.channel_connection_ids(channel_id)
3502            .filter_map(|(connection_id, _)| {
3503                if participant_connection_ids.contains(&connection_id) {
3504                    None
3505                } else {
3506                    Some(connection_id)
3507                }
3508            });
3509    broadcast(None, non_participants, |peer_id| {
3510        session.peer.send(
3511            peer_id,
3512            proto::UpdateChannels {
3513                latest_channel_message_ids: vec![proto::ChannelMessageId {
3514                    channel_id: channel_id.to_proto(),
3515                    message_id: message_id.to_proto(),
3516                }],
3517                ..Default::default()
3518            },
3519        )
3520    });
3521    send_notifications(pool, &session.peer, notifications);
3522
3523    Ok(())
3524}
3525
3526/// Delete a channel message
3527async fn remove_channel_message(
3528    request: proto::RemoveChannelMessage,
3529    response: Response<proto::RemoveChannelMessage>,
3530    session: Session,
3531) -> Result<()> {
3532    let channel_id = ChannelId::from_proto(request.channel_id);
3533    let message_id = MessageId::from_proto(request.message_id);
3534    let (connection_ids, existing_notification_ids) = session
3535        .db()
3536        .await
3537        .remove_channel_message(channel_id, message_id, session.user_id())
3538        .await?;
3539
3540    broadcast(
3541        Some(session.connection_id),
3542        connection_ids,
3543        move |connection| {
3544            session.peer.send(connection, request.clone())?;
3545
3546            for notification_id in &existing_notification_ids {
3547                session.peer.send(
3548                    connection,
3549                    proto::DeleteNotification {
3550                        notification_id: (*notification_id).to_proto(),
3551                    },
3552                )?;
3553            }
3554
3555            Ok(())
3556        },
3557    );
3558    response.send(proto::Ack {})?;
3559    Ok(())
3560}
3561
3562async fn update_channel_message(
3563    request: proto::UpdateChannelMessage,
3564    response: Response<proto::UpdateChannelMessage>,
3565    session: Session,
3566) -> Result<()> {
3567    let channel_id = ChannelId::from_proto(request.channel_id);
3568    let message_id = MessageId::from_proto(request.message_id);
3569    let updated_at = OffsetDateTime::now_utc();
3570    let UpdatedChannelMessage {
3571        message_id,
3572        participant_connection_ids,
3573        notifications,
3574        reply_to_message_id,
3575        timestamp,
3576        deleted_mention_notification_ids,
3577        updated_mention_notifications,
3578    } = session
3579        .db()
3580        .await
3581        .update_channel_message(
3582            channel_id,
3583            message_id,
3584            session.user_id(),
3585            request.body.as_str(),
3586            &request.mentions,
3587            updated_at,
3588        )
3589        .await?;
3590
3591    let nonce = request
3592        .nonce
3593        .clone()
3594        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3595
3596    let message = proto::ChannelMessage {
3597        sender_id: session.user_id().to_proto(),
3598        id: message_id.to_proto(),
3599        body: request.body.clone(),
3600        mentions: request.mentions.clone(),
3601        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3602        nonce: Some(nonce),
3603        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3604        edited_at: Some(updated_at.unix_timestamp() as u64),
3605    };
3606
3607    response.send(proto::Ack {})?;
3608
3609    let pool = &*session.connection_pool().await;
3610    broadcast(
3611        Some(session.connection_id),
3612        participant_connection_ids,
3613        |connection| {
3614            session.peer.send(
3615                connection,
3616                proto::ChannelMessageUpdate {
3617                    channel_id: channel_id.to_proto(),
3618                    message: Some(message.clone()),
3619                },
3620            )?;
3621
3622            for notification_id in &deleted_mention_notification_ids {
3623                session.peer.send(
3624                    connection,
3625                    proto::DeleteNotification {
3626                        notification_id: (*notification_id).to_proto(),
3627                    },
3628                )?;
3629            }
3630
3631            for notification in &updated_mention_notifications {
3632                session.peer.send(
3633                    connection,
3634                    proto::UpdateNotification {
3635                        notification: Some(notification.clone()),
3636                    },
3637                )?;
3638            }
3639
3640            Ok(())
3641        },
3642    );
3643
3644    send_notifications(pool, &session.peer, notifications);
3645
3646    Ok(())
3647}
3648
3649/// Mark a channel message as read
3650async fn acknowledge_channel_message(
3651    request: proto::AckChannelMessage,
3652    session: Session,
3653) -> Result<()> {
3654    let channel_id = ChannelId::from_proto(request.channel_id);
3655    let message_id = MessageId::from_proto(request.message_id);
3656    let notifications = session
3657        .db()
3658        .await
3659        .observe_channel_message(channel_id, session.user_id(), message_id)
3660        .await?;
3661    send_notifications(
3662        &*session.connection_pool().await,
3663        &session.peer,
3664        notifications,
3665    );
3666    Ok(())
3667}
3668
3669/// Mark a buffer version as synced
3670async fn acknowledge_buffer_version(
3671    request: proto::AckBufferOperation,
3672    session: Session,
3673) -> Result<()> {
3674    let buffer_id = BufferId::from_proto(request.buffer_id);
3675    session
3676        .db()
3677        .await
3678        .observe_buffer_version(
3679            buffer_id,
3680            session.user_id(),
3681            request.epoch as i32,
3682            &request.version,
3683        )
3684        .await?;
3685    Ok(())
3686}
3687
3688async fn count_language_model_tokens(
3689    request: proto::CountLanguageModelTokens,
3690    response: Response<proto::CountLanguageModelTokens>,
3691    session: Session,
3692    config: &Config,
3693) -> Result<()> {
3694    authorize_access_to_legacy_llm_endpoints(&session).await?;
3695
3696    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3697        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3698        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3699    };
3700
3701    session
3702        .app_state
3703        .rate_limiter
3704        .check(&*rate_limit, session.user_id())
3705        .await?;
3706
3707    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3708        Some(proto::LanguageModelProvider::Google) => {
3709            let api_key = config
3710                .google_ai_api_key
3711                .as_ref()
3712                .context("no Google AI API key configured on the server")?;
3713            google_ai::count_tokens(
3714                session.http_client.as_ref(),
3715                google_ai::API_URL,
3716                api_key,
3717                serde_json::from_str(&request.request)?,
3718            )
3719            .await?
3720        }
3721        _ => return Err(anyhow!("unsupported provider"))?,
3722    };
3723
3724    response.send(proto::CountLanguageModelTokensResponse {
3725        token_count: result.total_tokens as u32,
3726    })?;
3727
3728    Ok(())
3729}
3730
3731struct ZedProCountLanguageModelTokensRateLimit;
3732
3733impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3734    fn capacity(&self) -> usize {
3735        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3736            .ok()
3737            .and_then(|v| v.parse().ok())
3738            .unwrap_or(600) // Picked arbitrarily
3739    }
3740
3741    fn refill_duration(&self) -> chrono::Duration {
3742        chrono::Duration::hours(1)
3743    }
3744
3745    fn db_name(&self) -> &'static str {
3746        "zed-pro:count-language-model-tokens"
3747    }
3748}
3749
3750struct FreeCountLanguageModelTokensRateLimit;
3751
3752impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3753    fn capacity(&self) -> usize {
3754        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3755            .ok()
3756            .and_then(|v| v.parse().ok())
3757            .unwrap_or(600 / 10) // Picked arbitrarily
3758    }
3759
3760    fn refill_duration(&self) -> chrono::Duration {
3761        chrono::Duration::hours(1)
3762    }
3763
3764    fn db_name(&self) -> &'static str {
3765        "free:count-language-model-tokens"
3766    }
3767}
3768
3769struct ZedProComputeEmbeddingsRateLimit;
3770
3771impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3772    fn capacity(&self) -> usize {
3773        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3774            .ok()
3775            .and_then(|v| v.parse().ok())
3776            .unwrap_or(5000) // Picked arbitrarily
3777    }
3778
3779    fn refill_duration(&self) -> chrono::Duration {
3780        chrono::Duration::hours(1)
3781    }
3782
3783    fn db_name(&self) -> &'static str {
3784        "zed-pro:compute-embeddings"
3785    }
3786}
3787
3788struct FreeComputeEmbeddingsRateLimit;
3789
3790impl RateLimit for FreeComputeEmbeddingsRateLimit {
3791    fn capacity(&self) -> usize {
3792        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3793            .ok()
3794            .and_then(|v| v.parse().ok())
3795            .unwrap_or(5000 / 10) // Picked arbitrarily
3796    }
3797
3798    fn refill_duration(&self) -> chrono::Duration {
3799        chrono::Duration::hours(1)
3800    }
3801
3802    fn db_name(&self) -> &'static str {
3803        "free:compute-embeddings"
3804    }
3805}
3806
3807async fn compute_embeddings(
3808    request: proto::ComputeEmbeddings,
3809    response: Response<proto::ComputeEmbeddings>,
3810    session: Session,
3811    api_key: Option<Arc<str>>,
3812) -> Result<()> {
3813    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3814    authorize_access_to_legacy_llm_endpoints(&session).await?;
3815
3816    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3817        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3818        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3819    };
3820
3821    session
3822        .app_state
3823        .rate_limiter
3824        .check(&*rate_limit, session.user_id())
3825        .await?;
3826
3827    let embeddings = match request.model.as_str() {
3828        "openai/text-embedding-3-small" => {
3829            open_ai::embed(
3830                session.http_client.as_ref(),
3831                OPEN_AI_API_URL,
3832                &api_key,
3833                OpenAiEmbeddingModel::TextEmbedding3Small,
3834                request.texts.iter().map(|text| text.as_str()),
3835            )
3836            .await?
3837        }
3838        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3839    };
3840
3841    let embeddings = request
3842        .texts
3843        .iter()
3844        .map(|text| {
3845            let mut hasher = sha2::Sha256::new();
3846            hasher.update(text.as_bytes());
3847            let result = hasher.finalize();
3848            result.to_vec()
3849        })
3850        .zip(
3851            embeddings
3852                .data
3853                .into_iter()
3854                .map(|embedding| embedding.embedding),
3855        )
3856        .collect::<HashMap<_, _>>();
3857
3858    let db = session.db().await;
3859    db.save_embeddings(&request.model, &embeddings)
3860        .await
3861        .context("failed to save embeddings")
3862        .trace_err();
3863
3864    response.send(proto::ComputeEmbeddingsResponse {
3865        embeddings: embeddings
3866            .into_iter()
3867            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3868            .collect(),
3869    })?;
3870    Ok(())
3871}
3872
3873async fn get_cached_embeddings(
3874    request: proto::GetCachedEmbeddings,
3875    response: Response<proto::GetCachedEmbeddings>,
3876    session: Session,
3877) -> Result<()> {
3878    authorize_access_to_legacy_llm_endpoints(&session).await?;
3879
3880    let db = session.db().await;
3881    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3882
3883    response.send(proto::GetCachedEmbeddingsResponse {
3884        embeddings: embeddings
3885            .into_iter()
3886            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3887            .collect(),
3888    })?;
3889    Ok(())
3890}
3891
3892/// This is leftover from before the LLM service.
3893///
3894/// The endpoints protected by this check will be moved there eventually.
3895async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3896    if session.is_staff() {
3897        Ok(())
3898    } else {
3899        Err(anyhow!("permission denied"))?
3900    }
3901}
3902
3903/// Get a Supermaven API key for the user
3904async fn get_supermaven_api_key(
3905    _request: proto::GetSupermavenApiKey,
3906    response: Response<proto::GetSupermavenApiKey>,
3907    session: Session,
3908) -> Result<()> {
3909    let user_id: String = session.user_id().to_string();
3910    if !session.is_staff() {
3911        return Err(anyhow!("supermaven not enabled for this account"))?;
3912    }
3913
3914    let email = session
3915        .email()
3916        .ok_or_else(|| anyhow!("user must have an email"))?;
3917
3918    let supermaven_admin_api = session
3919        .supermaven_client
3920        .as_ref()
3921        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3922
3923    let result = supermaven_admin_api
3924        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3925        .await?;
3926
3927    response.send(proto::GetSupermavenApiKeyResponse {
3928        api_key: result.api_key,
3929    })?;
3930
3931    Ok(())
3932}
3933
3934/// Start receiving chat updates for a channel
3935async fn join_channel_chat(
3936    request: proto::JoinChannelChat,
3937    response: Response<proto::JoinChannelChat>,
3938    session: Session,
3939) -> Result<()> {
3940    let channel_id = ChannelId::from_proto(request.channel_id);
3941
3942    let db = session.db().await;
3943    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3944        .await?;
3945    let messages = db
3946        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3947        .await?;
3948    response.send(proto::JoinChannelChatResponse {
3949        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3950        messages,
3951    })?;
3952    Ok(())
3953}
3954
3955/// Stop receiving chat updates for a channel
3956async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3957    let channel_id = ChannelId::from_proto(request.channel_id);
3958    session
3959        .db()
3960        .await
3961        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3962        .await?;
3963    Ok(())
3964}
3965
3966/// Retrieve the chat history for a channel
3967async fn get_channel_messages(
3968    request: proto::GetChannelMessages,
3969    response: Response<proto::GetChannelMessages>,
3970    session: Session,
3971) -> Result<()> {
3972    let channel_id = ChannelId::from_proto(request.channel_id);
3973    let messages = session
3974        .db()
3975        .await
3976        .get_channel_messages(
3977            channel_id,
3978            session.user_id(),
3979            MESSAGE_COUNT_PER_PAGE,
3980            Some(MessageId::from_proto(request.before_message_id)),
3981        )
3982        .await?;
3983    response.send(proto::GetChannelMessagesResponse {
3984        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3985        messages,
3986    })?;
3987    Ok(())
3988}
3989
3990/// Retrieve specific chat messages
3991async fn get_channel_messages_by_id(
3992    request: proto::GetChannelMessagesById,
3993    response: Response<proto::GetChannelMessagesById>,
3994    session: Session,
3995) -> Result<()> {
3996    let message_ids = request
3997        .message_ids
3998        .iter()
3999        .map(|id| MessageId::from_proto(*id))
4000        .collect::<Vec<_>>();
4001    let messages = session
4002        .db()
4003        .await
4004        .get_channel_messages_by_id(session.user_id(), &message_ids)
4005        .await?;
4006    response.send(proto::GetChannelMessagesResponse {
4007        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4008        messages,
4009    })?;
4010    Ok(())
4011}
4012
4013/// Retrieve the current users notifications
4014async fn get_notifications(
4015    request: proto::GetNotifications,
4016    response: Response<proto::GetNotifications>,
4017    session: Session,
4018) -> Result<()> {
4019    let notifications = session
4020        .db()
4021        .await
4022        .get_notifications(
4023            session.user_id(),
4024            NOTIFICATION_COUNT_PER_PAGE,
4025            request.before_id.map(db::NotificationId::from_proto),
4026        )
4027        .await?;
4028    response.send(proto::GetNotificationsResponse {
4029        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4030        notifications,
4031    })?;
4032    Ok(())
4033}
4034
4035/// Mark notifications as read
4036async fn mark_notification_as_read(
4037    request: proto::MarkNotificationRead,
4038    response: Response<proto::MarkNotificationRead>,
4039    session: Session,
4040) -> Result<()> {
4041    let database = &session.db().await;
4042    let notifications = database
4043        .mark_notification_as_read_by_id(
4044            session.user_id(),
4045            NotificationId::from_proto(request.notification_id),
4046        )
4047        .await?;
4048    send_notifications(
4049        &*session.connection_pool().await,
4050        &session.peer,
4051        notifications,
4052    );
4053    response.send(proto::Ack {})?;
4054    Ok(())
4055}
4056
4057/// Get the current users information
4058async fn get_private_user_info(
4059    _request: proto::GetPrivateUserInfo,
4060    response: Response<proto::GetPrivateUserInfo>,
4061    session: Session,
4062) -> Result<()> {
4063    let db = session.db().await;
4064
4065    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4066    let user = db
4067        .get_user_by_id(session.user_id())
4068        .await?
4069        .ok_or_else(|| anyhow!("user not found"))?;
4070    let flags = db.get_user_flags(session.user_id()).await?;
4071
4072    response.send(proto::GetPrivateUserInfoResponse {
4073        metrics_id,
4074        staff: user.admin,
4075        flags,
4076        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4077    })?;
4078    Ok(())
4079}
4080
4081/// Accept the terms of service (tos) on behalf of the current user
4082async fn accept_terms_of_service(
4083    _request: proto::AcceptTermsOfService,
4084    response: Response<proto::AcceptTermsOfService>,
4085    session: Session,
4086) -> Result<()> {
4087    let db = session.db().await;
4088
4089    let accepted_tos_at = Utc::now();
4090    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4091        .await?;
4092
4093    response.send(proto::AcceptTermsOfServiceResponse {
4094        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4095    })?;
4096    Ok(())
4097}
4098
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, expressed as a `chrono::Duration`.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4101
4102async fn get_llm_api_token(
4103    _request: proto::GetLlmToken,
4104    response: Response<proto::GetLlmToken>,
4105    session: Session,
4106) -> Result<()> {
4107    let db = session.db().await;
4108
4109    let flags = db.get_user_flags(session.user_id()).await?;
4110    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4111
4112    if !session.is_staff() && !has_language_models_feature_flag {
4113        Err(anyhow!("permission denied"))?
4114    }
4115
4116    let user_id = session.user_id();
4117    let user = db
4118        .get_user_by_id(user_id)
4119        .await?
4120        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4121
4122    if user.accepted_tos_at.is_none() {
4123        Err(anyhow!("terms of service not accepted"))?
4124    }
4125
4126    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4127    let billing_preferences = db.get_billing_preferences(user.id).await?;
4128
4129    let token = LlmTokenClaims::create(
4130        &user,
4131        session.is_staff(),
4132        billing_preferences,
4133        &flags,
4134        has_llm_subscription,
4135        session.current_plan(&db).await?,
4136        session.system_id.clone(),
4137        &session.app_state.config,
4138    )?;
4139    response.send(proto::GetLlmTokenResponse { token })?;
4140    Ok(())
4141}
4142
4143fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4144    let message = match message {
4145        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4146        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4147        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4148        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4149        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4150            code: frame.code.into(),
4151            reason: frame.reason,
4152        })),
4153        // We should never receive a frame while reading the message, according
4154        // to the `tungstenite` maintainers:
4155        //
4156        // > It cannot occur when you read messages from the WebSocket, but it
4157        // > can be used when you want to send the raw frames (e.g. you want to
4158        // > send the frames to the WebSocket without composing the full message first).
4159        // >
4160        // > — https://github.com/snapview/tungstenite-rs/issues/268
4161        TungsteniteMessage::Frame(_) => {
4162            bail!("received an unexpected frame while reading the message")
4163        }
4164    };
4165
4166    Ok(message)
4167}
4168
4169fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4170    match message {
4171        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4172        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4173        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4174        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4175        AxumMessage::Close(frame) => {
4176            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4177                code: frame.code.into(),
4178                reason: frame.reason,
4179            }))
4180        }
4181    }
4182}
4183
4184fn notify_membership_updated(
4185    connection_pool: &mut ConnectionPool,
4186    result: MembershipUpdated,
4187    user_id: UserId,
4188    peer: &Peer,
4189) {
4190    for membership in &result.new_channels.channel_memberships {
4191        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4192    }
4193    for channel_id in &result.removed_channels {
4194        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4195    }
4196
4197    let user_channels_update = proto::UpdateUserChannels {
4198        channel_memberships: result
4199            .new_channels
4200            .channel_memberships
4201            .iter()
4202            .map(|cm| proto::ChannelMembership {
4203                channel_id: cm.channel_id.to_proto(),
4204                role: cm.role.into(),
4205            })
4206            .collect(),
4207        ..Default::default()
4208    };
4209
4210    let mut update = build_channels_update(result.new_channels);
4211    update.delete_channels = result
4212        .removed_channels
4213        .into_iter()
4214        .map(|id| id.to_proto())
4215        .collect();
4216    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4217
4218    for connection_id in connection_pool.user_connection_ids(user_id) {
4219        peer.send(connection_id, user_channels_update.clone())
4220            .trace_err();
4221        peer.send(connection_id, update.clone()).trace_err();
4222    }
4223}
4224
4225fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4226    proto::UpdateUserChannels {
4227        channel_memberships: channels
4228            .channel_memberships
4229            .iter()
4230            .map(|m| proto::ChannelMembership {
4231                channel_id: m.channel_id.to_proto(),
4232                role: m.role.into(),
4233            })
4234            .collect(),
4235        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4236        observed_channel_message_id: channels.observed_channel_messages.clone(),
4237    }
4238}
4239
4240fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4241    let mut update = proto::UpdateChannels::default();
4242
4243    for channel in channels.channels {
4244        update.channels.push(channel.to_proto());
4245    }
4246
4247    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4248    update.latest_channel_message_ids = channels.latest_channel_messages;
4249
4250    for (channel_id, participants) in channels.channel_participants {
4251        update
4252            .channel_participants
4253            .push(proto::ChannelParticipants {
4254                channel_id: channel_id.to_proto(),
4255                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4256            });
4257    }
4258
4259    for channel in channels.invited_channels {
4260        update.channel_invitations.push(channel.to_proto());
4261    }
4262
4263    update
4264}
4265
4266fn build_initial_contacts_update(
4267    contacts: Vec<db::Contact>,
4268    pool: &ConnectionPool,
4269) -> proto::UpdateContacts {
4270    let mut update = proto::UpdateContacts::default();
4271
4272    for contact in contacts {
4273        match contact {
4274            db::Contact::Accepted { user_id, busy } => {
4275                update.contacts.push(contact_for_user(user_id, busy, pool));
4276            }
4277            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4278            db::Contact::Incoming { user_id } => {
4279                update
4280                    .incoming_requests
4281                    .push(proto::IncomingContactRequest {
4282                        requester_id: user_id.to_proto(),
4283                    })
4284            }
4285        }
4286    }
4287
4288    update
4289}
4290
4291fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4292    proto::Contact {
4293        user_id: user_id.to_proto(),
4294        online: pool.is_user_online(user_id),
4295        busy,
4296    }
4297}
4298
4299fn room_updated(room: &proto::Room, peer: &Peer) {
4300    broadcast(
4301        None,
4302        room.participants
4303            .iter()
4304            .filter_map(|participant| Some(participant.peer_id?.into())),
4305        |peer_id| {
4306            peer.send(
4307                peer_id,
4308                proto::RoomUpdated {
4309                    room: Some(room.clone()),
4310                },
4311            )
4312        },
4313    );
4314}
4315
4316fn channel_updated(
4317    channel: &db::channel::Model,
4318    room: &proto::Room,
4319    peer: &Peer,
4320    pool: &ConnectionPool,
4321) {
4322    let participants = room
4323        .participants
4324        .iter()
4325        .map(|p| p.user_id)
4326        .collect::<Vec<_>>();
4327
4328    broadcast(
4329        None,
4330        pool.channel_connection_ids(channel.root_id())
4331            .filter_map(|(channel_id, role)| {
4332                role.can_see_channel(channel.visibility)
4333                    .then_some(channel_id)
4334            }),
4335        |peer_id| {
4336            peer.send(
4337                peer_id,
4338                proto::UpdateChannels {
4339                    channel_participants: vec![proto::ChannelParticipants {
4340                        channel_id: channel.id.to_proto(),
4341                        participant_user_ids: participants.clone(),
4342                    }],
4343                    ..Default::default()
4344                },
4345            )
4346        },
4347    );
4348}
4349
4350async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4351    let db = session.db().await;
4352
4353    let contacts = db.get_contacts(user_id).await?;
4354    let busy = db.is_user_busy(user_id).await?;
4355
4356    let pool = session.connection_pool().await;
4357    let updated_contact = contact_for_user(user_id, busy, &pool);
4358    for contact in contacts {
4359        if let db::Contact::Accepted {
4360            user_id: contact_user_id,
4361            ..
4362        } = contact
4363        {
4364            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4365                session
4366                    .peer
4367                    .send(
4368                        contact_conn_id,
4369                        proto::UpdateContacts {
4370                            contacts: vec![updated_contact.clone()],
4371                            remove_contacts: Default::default(),
4372                            incoming_requests: Default::default(),
4373                            remove_incoming_requests: Default::default(),
4374                            outgoing_requests: Default::default(),
4375                            remove_outgoing_requests: Default::default(),
4376                        },
4377                    )
4378                    .trace_err();
4379            }
4380        }
4381    }
4382    Ok(())
4383}
4384
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// Steps, in order:
/// 1. `leave_room` is run against the database. If it returns `None`, there is nothing to
///    do and this returns `Ok(())` immediately.
/// 2. Every project the connection left is handled by `project_left`. The updated room is
///    then broadcast to its participants via `room_updated`.
/// 3. If the room belongs to a channel, the channel's participant list is re-broadcast.
/// 4. Every user in `canceled_calls_to_user_ids` receives `CallCanceled` for this room on all
///    of their connections.
/// 5. The contact status of the leaving user and of each canceled user is refreshed.
/// 6. If a LiveKit client is configured, the user is removed from the LiveKit room. The LiveKit
///    room is also deleted when `left_room.deleted` is set.
///
/// # Errors
/// Returns database errors from `leave_room` and `update_user_contacts`. An error from
/// `update_user_contacts` returns early and skips the LiveKit cleanup in step 6. Message sends
/// and LiveKit calls are best-effort: their failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact status must be refreshed: the leaving user plus every user whose
    // call into this room was canceled.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the `Some` branch below; the `else`
    // branch returns, so they are always initialized after the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The value returned by `session.db().await` is a temporary in the `if let` scrutinee, so
    // it stays alive for the whole `Some` branch: `project_left` and `room_updated` run while
    // it is held (presumably a per-session database guard — confirm). Converting this to
    // `let ... else` or `match` would drop it earlier.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // `room.id` is read before `room` is taken out below.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before `room` itself is
        // taken, so the `room` broadcast by `room_updated` below carries a default (empty)
        // `livekit_room`. Presumably receivers don't need it — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The pool guard is a temporary here and is released at the end of this statement.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`, which acquires the
    // connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Note: this loop variable shadows the `connection_id` parameter.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // A failure here propagates via `?` and skips the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: errors are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4458
4459async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4460    let left_channel_buffers = session
4461        .db()
4462        .await
4463        .leave_channel_buffers(session.connection_id)
4464        .await?;
4465
4466    for left_buffer in left_channel_buffers {
4467        channel_buffer_updated(
4468            session.connection_id,
4469            left_buffer.connections,
4470            &proto::UpdateChannelBufferCollaborators {
4471                channel_id: left_buffer.channel_id.to_proto(),
4472                collaborators: left_buffer.collaborators,
4473            },
4474            &session.peer,
4475        );
4476    }
4477
4478    Ok(())
4479}
4480
4481fn project_left(project: &db::LeftProject, session: &Session) {
4482    for connection_id in &project.connection_ids {
4483        if project.should_unshare {
4484            session
4485                .peer
4486                .send(
4487                    *connection_id,
4488                    proto::UnshareProject {
4489                        project_id: project.id.to_proto(),
4490                    },
4491                )
4492                .trace_err();
4493        } else {
4494            session
4495                .peer
4496                .send(
4497                    *connection_id,
4498                    proto::RemoveProjectCollaborator {
4499                        project_id: project.id.to_proto(),
4500                        peer_id: Some(session.connection_id.into()),
4501                    },
4502                )
4503                .trace_err();
4504        }
4505    }
4506}
4507
/// Extension trait that turns a fallible value into an `Option`, recording the failure
/// instead of propagating it.
pub trait ResultExt {
    /// The value yielded on success.
    type Ok;

    /// Returns `Some` with the success value, or `None` on failure.
    ///
    /// On failure, the implementation records the error; the `Result` implementation in this
    /// file logs it with `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4513
4514impl<T, E> ResultExt for Result<T, E>
4515where
4516    E: std::fmt::Debug,
4517{
4518    type Ok = T;
4519
4520    #[track_caller]
4521    fn trace_err(self) -> Option<T> {
4522        match self {
4523            Ok(value) => Some(value),
4524            Err(error) => {
4525                tracing::error!("{:?}", error);
4526                None
4527            }
4528        }
4529    }
4530}