rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81use self::connection_pool::VersionedMessage;
  82
  83pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  84
  85// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  86pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  87
  88const MESSAGE_COUNT_PER_PAGE: usize = 100;
  89const MAX_MESSAGE_LEN: usize = 1024;
  90const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  94
  95struct Response<R> {
  96    peer: Arc<Peer>,
  97    receipt: Receipt<R>,
  98    responded: Arc<AtomicBool>,
  99}
 100
 101impl<R: RequestMessage> Response<R> {
 102    fn send(self, payload: R::Response) -> Result<()> {
 103        self.responded.store(true, SeqCst);
 104        self.peer.respond(self.receipt, payload)?;
 105        Ok(())
 106    }
 107}
 108
 109#[derive(Clone, Debug)]
 110pub enum Principal {
 111    User(User),
 112    Impersonated { user: User, admin: User },
 113    DevServer(dev_server::Model),
 114}
 115
 116impl Principal {
 117    fn update_span(&self, span: &tracing::Span) {
 118        match &self {
 119            Principal::User(user) => {
 120                span.record("user_id", &user.id.0);
 121                span.record("login", &user.github_login);
 122            }
 123            Principal::Impersonated { user, admin } => {
 124                span.record("user_id", &user.id.0);
 125                span.record("login", &user.github_login);
 126                span.record("impersonator", &admin.github_login);
 127            }
 128            Principal::DevServer(dev_server) => {
 129                span.record("dev_server_id", &dev_server.id.0);
 130            }
 131        }
 132    }
 133}
 134
 135#[derive(Clone)]
 136struct Session {
 137    principal: Principal,
 138    connection_id: ConnectionId,
 139    db: Arc<tokio::sync::Mutex<DbHandle>>,
 140    peer: Arc<Peer>,
 141    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 142    app_state: Arc<AppState>,
 143    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 144    http_client: Arc<IsahcHttpClient>,
 145    /// The GeoIP country code for the user.
 146    #[allow(unused)]
 147    geoip_country_code: Option<String>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn for_dev_server(self) -> Option<DevServerSession> {
 176        DevServerSession::new(self)
 177    }
 178
 179    fn user_id(&self) -> Option<UserId> {
 180        match &self.principal {
 181            Principal::User(user) => Some(user.id),
 182            Principal::Impersonated { user, .. } => Some(user.id),
 183            Principal::DevServer(_) => None,
 184        }
 185    }
 186
 187    fn is_staff(&self) -> bool {
 188        match &self.principal {
 189            Principal::User(user) => user.admin,
 190            Principal::Impersonated { .. } => true,
 191            Principal::DevServer(_) => false,
 192        }
 193    }
 194
 195    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 196        if self.is_staff() {
 197            return Ok(proto::Plan::ZedPro);
 198        }
 199
 200        let Some(user_id) = self.user_id() else {
 201            return Ok(proto::Plan::Free);
 202        };
 203
 204        if db.has_active_billing_subscription(user_id).await? {
 205            Ok(proto::Plan::ZedPro)
 206        } else {
 207            Ok(proto::Plan::Free)
 208        }
 209    }
 210
 211    fn dev_server_id(&self) -> Option<DevServerId> {
 212        match &self.principal {
 213            Principal::User(_) | Principal::Impersonated { .. } => None,
 214            Principal::DevServer(dev_server) => Some(dev_server.id),
 215        }
 216    }
 217
 218    fn principal_id(&self) -> PrincipalId {
 219        match &self.principal {
 220            Principal::User(user) => PrincipalId::UserId(user.id),
 221            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 222            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 223        }
 224    }
 225}
 226
 227impl Debug for Session {
 228    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 229        let mut result = f.debug_struct("Session");
 230        match &self.principal {
 231            Principal::User(user) => {
 232                result.field("user", &user.github_login);
 233            }
 234            Principal::Impersonated { user, admin } => {
 235                result.field("user", &user.github_login);
 236                result.field("impersonator", &admin.github_login);
 237            }
 238            Principal::DevServer(dev_server) => {
 239                result.field("dev_server", &dev_server.id);
 240            }
 241        }
 242        result.field("connection_id", &self.connection_id).finish()
 243    }
 244}
 245
 246struct UserSession(Session);
 247
 248impl UserSession {
 249    pub fn new(s: Session) -> Option<Self> {
 250        s.user_id().map(|_| UserSession(s))
 251    }
 252    pub fn user_id(&self) -> UserId {
 253        self.0.user_id().unwrap()
 254    }
 255
 256    pub fn email(&self) -> Option<String> {
 257        match &self.0.principal {
 258            Principal::User(user) => user.email_address.clone(),
 259            Principal::Impersonated { user, .. } => user.email_address.clone(),
 260            Principal::DevServer(..) => None,
 261        }
 262    }
 263}
 264
 265impl Deref for UserSession {
 266    type Target = Session;
 267
 268    fn deref(&self) -> &Self::Target {
 269        &self.0
 270    }
 271}
 272impl DerefMut for UserSession {
 273    fn deref_mut(&mut self) -> &mut Self::Target {
 274        &mut self.0
 275    }
 276}
 277
 278struct DevServerSession(Session);
 279
 280impl DevServerSession {
 281    pub fn new(s: Session) -> Option<Self> {
 282        s.dev_server_id().map(|_| DevServerSession(s))
 283    }
 284    pub fn dev_server_id(&self) -> DevServerId {
 285        self.0.dev_server_id().unwrap()
 286    }
 287
 288    fn dev_server(&self) -> &dev_server::Model {
 289        match &self.0.principal {
 290            Principal::DevServer(dev_server) => dev_server,
 291            _ => unreachable!(),
 292        }
 293    }
 294}
 295
 296impl Deref for DevServerSession {
 297    type Target = Session;
 298
 299    fn deref(&self) -> &Self::Target {
 300        &self.0
 301    }
 302}
 303impl DerefMut for DevServerSession {
 304    fn deref_mut(&mut self) -> &mut Self::Target {
 305        &mut self.0
 306    }
 307}
 308
 309fn user_handler<M: RequestMessage, Fut>(
 310    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 311) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 312where
 313    Fut: Send + Future<Output = Result<()>>,
 314{
 315    let handler = Arc::new(handler);
 316    move |message, response, session| {
 317        let handler = handler.clone();
 318        Box::pin(async move {
 319            if let Some(user_session) = session.for_user() {
 320                Ok(handler(message, response, user_session).await?)
 321            } else {
 322                Err(Error::Internal(anyhow!(
 323                    "must be a user to call {}",
 324                    M::NAME
 325                )))
 326            }
 327        })
 328    }
 329}
 330
 331fn dev_server_handler<M: RequestMessage, Fut>(
 332    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 333) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 334where
 335    Fut: Send + Future<Output = Result<()>>,
 336{
 337    let handler = Arc::new(handler);
 338    move |message, response, session| {
 339        let handler = handler.clone();
 340        Box::pin(async move {
 341            if let Some(dev_server_session) = session.for_dev_server() {
 342                Ok(handler(message, response, dev_server_session).await?)
 343            } else {
 344                Err(Error::Internal(anyhow!(
 345                    "must be a dev server to call {}",
 346                    M::NAME
 347                )))
 348            }
 349        })
 350    }
 351}
 352
 353fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 354    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 355) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 356where
 357    InnertRetFut: Send + Future<Output = Result<()>>,
 358{
 359    let handler = Arc::new(handler);
 360    move |message, session| {
 361        let handler = handler.clone();
 362        Box::pin(async move {
 363            if let Some(user_session) = session.for_user() {
 364                Ok(handler(message, user_session).await?)
 365            } else {
 366                Err(Error::Internal(anyhow!(
 367                    "must be a user to call {}",
 368                    M::NAME
 369                )))
 370            }
 371        })
 372    }
 373}
 374
 375struct DbHandle(Arc<Database>);
 376
 377impl Deref for DbHandle {
 378    type Target = Database;
 379
 380    fn deref(&self) -> &Self::Target {
 381        self.0.as_ref()
 382    }
 383}
 384
 385pub struct Server {
 386    id: parking_lot::Mutex<ServerId>,
 387    peer: Arc<Peer>,
 388    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 389    app_state: Arc<AppState>,
 390    handlers: HashMap<TypeId, MessageHandler>,
 391    teardown: watch::Sender<bool>,
 392}
 393
 394pub(crate) struct ConnectionPoolGuard<'a> {
 395    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 396    _not_send: PhantomData<Rc<()>>,
 397}
 398
 399#[derive(Serialize)]
 400pub struct ServerSnapshot<'a> {
 401    peer: &'a Peer,
 402    #[serde(serialize_with = "serialize_deref")]
 403    connection_pool: ConnectionPoolGuard<'a>,
 404}
 405
 406pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 407where
 408    S: Serializer,
 409    T: Deref<Target = U>,
 410    U: Serialize,
 411{
 412    Serialize::serialize(value.deref(), serializer)
 413}
 414
 415impl Server {
 416    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 417        let mut server = Self {
 418            id: parking_lot::Mutex::new(id),
 419            peer: Peer::new(id.0 as u32),
 420            app_state: app_state.clone(),
 421            connection_pool: Default::default(),
 422            handlers: Default::default(),
 423            teardown: watch::channel(false).0,
 424        };
 425
 426        server
 427            .add_request_handler(ping)
 428            .add_request_handler(user_handler(create_room))
 429            .add_request_handler(user_handler(join_room))
 430            .add_request_handler(user_handler(rejoin_room))
 431            .add_request_handler(user_handler(leave_room))
 432            .add_request_handler(user_handler(set_room_participant_role))
 433            .add_request_handler(user_handler(call))
 434            .add_request_handler(user_handler(cancel_call))
 435            .add_message_handler(user_message_handler(decline_call))
 436            .add_request_handler(user_handler(update_participant_location))
 437            .add_request_handler(user_handler(share_project))
 438            .add_message_handler(unshare_project)
 439            .add_request_handler(user_handler(join_project))
 440            .add_request_handler(user_handler(join_hosted_project))
 441            .add_request_handler(user_handler(rejoin_dev_server_projects))
 442            .add_request_handler(user_handler(create_dev_server_project))
 443            .add_request_handler(user_handler(update_dev_server_project))
 444            .add_request_handler(user_handler(delete_dev_server_project))
 445            .add_request_handler(user_handler(create_dev_server))
 446            .add_request_handler(user_handler(regenerate_dev_server_token))
 447            .add_request_handler(user_handler(rename_dev_server))
 448            .add_request_handler(user_handler(delete_dev_server))
 449            .add_request_handler(user_handler(list_remote_directory))
 450            .add_request_handler(dev_server_handler(share_dev_server_project))
 451            .add_request_handler(dev_server_handler(shutdown_dev_server))
 452            .add_request_handler(dev_server_handler(reconnect_dev_server))
 453            .add_message_handler(user_message_handler(leave_project))
 454            .add_request_handler(update_project)
 455            .add_request_handler(update_worktree)
 456            .add_message_handler(start_language_server)
 457            .add_message_handler(update_language_server)
 458            .add_message_handler(update_diagnostic_summary)
 459            .add_message_handler(update_worktree_settings)
 460            .add_request_handler(user_handler(
 461                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 462            ))
 463            .add_request_handler(user_handler(
 464                forward_project_request_for_owner::<proto::TaskTemplates>,
 465            ))
 466            .add_request_handler(user_handler(
 467                forward_read_only_project_request::<proto::GetHover>,
 468            ))
 469            .add_request_handler(user_handler(
 470                forward_read_only_project_request::<proto::GetDefinition>,
 471            ))
 472            .add_request_handler(user_handler(
 473                forward_read_only_project_request::<proto::GetTypeDefinition>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_read_only_project_request::<proto::GetReferences>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_read_only_project_request::<proto::SearchProject>,
 480            ))
 481            .add_request_handler(user_handler(forward_find_search_candidates_request))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::GetProjectSymbols>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_read_only_project_request::<proto::OpenBufferById>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_read_only_project_request::<proto::InlayHints>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_read_only_project_request::<proto::OpenBufferByPath>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCompletions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::GetCodeActions>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::ApplyCodeAction>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::PrepareRename>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::PerformRename>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::ReloadBuffers>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::FormatBuffers>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::CreateProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::RenameProjectEntry>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_mutating_project_request::<proto::CopyProjectEntry>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 547            ))
 548            .add_request_handler(user_handler(
 549                forward_mutating_project_request::<proto::OnTypeFormatting>,
 550            ))
 551            .add_request_handler(user_handler(
 552                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 553            ))
 554            .add_request_handler(user_handler(
 555                forward_mutating_project_request::<proto::BlameBuffer>,
 556            ))
 557            .add_request_handler(user_handler(
 558                forward_mutating_project_request::<proto::MultiLspQuery>,
 559            ))
 560            .add_request_handler(user_handler(
 561                forward_mutating_project_request::<proto::RestartLanguageServers>,
 562            ))
 563            .add_request_handler(user_handler(
 564                forward_mutating_project_request::<proto::LinkedEditingRange>,
 565            ))
 566            .add_message_handler(create_buffer_for_peer)
 567            .add_request_handler(update_buffer)
 568            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 569            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 570            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 571            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 572            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 573            .add_request_handler(get_users)
 574            .add_request_handler(user_handler(fuzzy_search_users))
 575            .add_request_handler(user_handler(request_contact))
 576            .add_request_handler(user_handler(remove_contact))
 577            .add_request_handler(user_handler(respond_to_contact_request))
 578            .add_message_handler(subscribe_to_channels)
 579            .add_request_handler(user_handler(create_channel))
 580            .add_request_handler(user_handler(delete_channel))
 581            .add_request_handler(user_handler(invite_channel_member))
 582            .add_request_handler(user_handler(remove_channel_member))
 583            .add_request_handler(user_handler(set_channel_member_role))
 584            .add_request_handler(user_handler(set_channel_visibility))
 585            .add_request_handler(user_handler(rename_channel))
 586            .add_request_handler(user_handler(join_channel_buffer))
 587            .add_request_handler(user_handler(leave_channel_buffer))
 588            .add_message_handler(user_message_handler(update_channel_buffer))
 589            .add_request_handler(user_handler(rejoin_channel_buffers))
 590            .add_request_handler(user_handler(get_channel_members))
 591            .add_request_handler(user_handler(respond_to_channel_invite))
 592            .add_request_handler(user_handler(join_channel))
 593            .add_request_handler(user_handler(join_channel_chat))
 594            .add_message_handler(user_message_handler(leave_channel_chat))
 595            .add_request_handler(user_handler(send_channel_message))
 596            .add_request_handler(user_handler(remove_channel_message))
 597            .add_request_handler(user_handler(update_channel_message))
 598            .add_request_handler(user_handler(get_channel_messages))
 599            .add_request_handler(user_handler(get_channel_messages_by_id))
 600            .add_request_handler(user_handler(get_notifications))
 601            .add_request_handler(user_handler(mark_notification_as_read))
 602            .add_request_handler(user_handler(move_channel))
 603            .add_request_handler(user_handler(follow))
 604            .add_message_handler(user_message_handler(unfollow))
 605            .add_message_handler(user_message_handler(update_followers))
 606            .add_request_handler(user_handler(get_private_user_info))
 607            .add_request_handler(user_handler(get_llm_api_token))
 608            .add_request_handler(user_handler(accept_terms_of_service))
 609            .add_message_handler(user_message_handler(acknowledge_channel_message))
 610            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 611            .add_request_handler(user_handler(get_supermaven_api_key))
 612            .add_request_handler(user_handler(
 613                forward_mutating_project_request::<proto::OpenContext>,
 614            ))
 615            .add_request_handler(user_handler(
 616                forward_mutating_project_request::<proto::CreateContext>,
 617            ))
 618            .add_request_handler(user_handler(
 619                forward_mutating_project_request::<proto::SynchronizeContexts>,
 620            ))
 621            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 622            .add_message_handler(update_context)
 623            .add_request_handler({
 624                let app_state = app_state.clone();
 625                move |request, response, session| {
 626                    let app_state = app_state.clone();
 627                    async move {
 628                        count_language_model_tokens(request, response, session, &app_state.config)
 629                            .await
 630                    }
 631                }
 632            })
 633            .add_request_handler({
 634                user_handler(move |request, response, session| {
 635                    get_cached_embeddings(request, response, session)
 636                })
 637            })
 638            .add_request_handler({
 639                let app_state = app_state.clone();
 640                user_handler(move |request, response, session| {
 641                    compute_embeddings(
 642                        request,
 643                        response,
 644                        session,
 645                        app_state.config.openai_api_key.clone(),
 646                    )
 647                })
 648            });
 649
 650        Arc::new(server)
 651    }
 652
 653    pub async fn start(&self) -> Result<()> {
 654        let server_id = *self.id.lock();
 655        let app_state = self.app_state.clone();
 656        let peer = self.peer.clone();
 657        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 658        let pool = self.connection_pool.clone();
 659        let live_kit_client = self.app_state.live_kit_client.clone();
 660
 661        let span = info_span!("start server");
 662        self.app_state.executor.spawn_detached(
 663            async move {
 664                tracing::info!("waiting for cleanup timeout");
 665                timeout.await;
 666                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 667                if let Some((room_ids, channel_ids)) = app_state
 668                    .db
 669                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 670                    .await
 671                    .trace_err()
 672                {
 673                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 674                    tracing::info!(
 675                        stale_channel_buffer_count = channel_ids.len(),
 676                        "retrieved stale channel buffers"
 677                    );
 678
 679                    for channel_id in channel_ids {
 680                        if let Some(refreshed_channel_buffer) = app_state
 681                            .db
 682                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 683                            .await
 684                            .trace_err()
 685                        {
 686                            for connection_id in refreshed_channel_buffer.connection_ids {
 687                                peer.send(
 688                                    connection_id,
 689                                    proto::UpdateChannelBufferCollaborators {
 690                                        channel_id: channel_id.to_proto(),
 691                                        collaborators: refreshed_channel_buffer
 692                                            .collaborators
 693                                            .clone(),
 694                                    },
 695                                )
 696                                .trace_err();
 697                            }
 698                        }
 699                    }
 700
 701                    for room_id in room_ids {
 702                        let mut contacts_to_update = HashSet::default();
 703                        let mut canceled_calls_to_user_ids = Vec::new();
 704                        let mut live_kit_room = String::new();
 705                        let mut delete_live_kit_room = false;
 706
 707                        if let Some(mut refreshed_room) = app_state
 708                            .db
 709                            .clear_stale_room_participants(room_id, server_id)
 710                            .await
 711                            .trace_err()
 712                        {
 713                            tracing::info!(
 714                                room_id = room_id.0,
 715                                new_participant_count = refreshed_room.room.participants.len(),
 716                                "refreshed room"
 717                            );
 718                            room_updated(&refreshed_room.room, &peer);
 719                            if let Some(channel) = refreshed_room.channel.as_ref() {
 720                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 721                            }
 722                            contacts_to_update
 723                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 724                            contacts_to_update
 725                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 726                            canceled_calls_to_user_ids =
 727                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 728                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 729                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 730                        }
 731
 732                        {
 733                            let pool = pool.lock();
 734                            for canceled_user_id in canceled_calls_to_user_ids {
 735                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 736                                    peer.send(
 737                                        connection_id,
 738                                        proto::CallCanceled {
 739                                            room_id: room_id.to_proto(),
 740                                        },
 741                                    )
 742                                    .trace_err();
 743                                }
 744                            }
 745                        }
 746
 747                        for user_id in contacts_to_update {
 748                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 749                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 750                            if let Some((busy, contacts)) = busy.zip(contacts) {
 751                                let pool = pool.lock();
 752                                let updated_contact = contact_for_user(user_id, busy, &pool);
 753                                for contact in contacts {
 754                                    if let db::Contact::Accepted {
 755                                        user_id: contact_user_id,
 756                                        ..
 757                                    } = contact
 758                                    {
 759                                        for contact_conn_id in
 760                                            pool.user_connection_ids(contact_user_id)
 761                                        {
 762                                            peer.send(
 763                                                contact_conn_id,
 764                                                proto::UpdateContacts {
 765                                                    contacts: vec![updated_contact.clone()],
 766                                                    remove_contacts: Default::default(),
 767                                                    incoming_requests: Default::default(),
 768                                                    remove_incoming_requests: Default::default(),
 769                                                    outgoing_requests: Default::default(),
 770                                                    remove_outgoing_requests: Default::default(),
 771                                                },
 772                                            )
 773                                            .trace_err();
 774                                        }
 775                                    }
 776                                }
 777                            }
 778                        }
 779
 780                        if let Some(live_kit) = live_kit_client.as_ref() {
 781                            if delete_live_kit_room {
 782                                live_kit.delete_room(live_kit_room).await.trace_err();
 783                            }
 784                        }
 785                    }
 786                }
 787
 788                app_state
 789                    .db
 790                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 791                    .await
 792                    .trace_err();
 793            }
 794            .instrument(span),
 795        );
 796        Ok(())
 797    }
 798
 799    pub fn teardown(&self) {
 800        self.peer.teardown();
 801        self.connection_pool.lock().reset();
 802        let _ = self.teardown.send(true);
 803    }
 804
 805    #[cfg(test)]
 806    pub fn reset(&self, id: ServerId) {
 807        self.teardown();
 808        *self.id.lock() = id;
 809        self.peer.reset(id.0 as u32);
 810        let _ = self.teardown.send(false);
 811    }
 812
 813    #[cfg(test)]
 814    pub fn id(&self) -> ServerId {
 815        *self.id.lock()
 816    }
 817
 818    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 819    where
 820        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 821        Fut: 'static + Send + Future<Output = Result<()>>,
 822        M: EnvelopedMessage,
 823    {
 824        let prev_handler = self.handlers.insert(
 825            TypeId::of::<M>(),
 826            Box::new(move |envelope, session| {
 827                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 828                let received_at = envelope.received_at;
 829                tracing::info!("message received");
 830                let start_time = Instant::now();
 831                let future = (handler)(*envelope, session);
 832                async move {
 833                    let result = future.await;
 834                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 835                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 836                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 837                    let payload_type = M::NAME;
 838
 839                    match result {
 840                        Err(error) => {
 841                            tracing::error!(
 842                                ?error,
 843                                total_duration_ms,
 844                                processing_duration_ms,
 845                                queue_duration_ms,
 846                                payload_type,
 847                                "error handling message"
 848                            )
 849                        }
 850                        Ok(()) => tracing::info!(
 851                            total_duration_ms,
 852                            processing_duration_ms,
 853                            queue_duration_ms,
 854                            "finished handling message"
 855                        ),
 856                    }
 857                }
 858                .boxed()
 859            }),
 860        );
 861        if prev_handler.is_some() {
 862            panic!("registered a handler for the same message twice");
 863        }
 864        self
 865    }
 866
 867    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 868    where
 869        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 870        Fut: 'static + Send + Future<Output = Result<()>>,
 871        M: EnvelopedMessage,
 872    {
 873        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 874        self
 875    }
 876
 877    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 878    where
 879        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 880        Fut: Send + Future<Output = Result<()>>,
 881        M: RequestMessage,
 882    {
 883        let handler = Arc::new(handler);
 884        self.add_handler(move |envelope, session| {
 885            let receipt = envelope.receipt();
 886            let handler = handler.clone();
 887            async move {
 888                let peer = session.peer.clone();
 889                let responded = Arc::new(AtomicBool::default());
 890                let response = Response {
 891                    peer: peer.clone(),
 892                    responded: responded.clone(),
 893                    receipt,
 894                };
 895                match (handler)(envelope.payload, response, session).await {
 896                    Ok(()) => {
 897                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 898                            Ok(())
 899                        } else {
 900                            Err(anyhow!("handler did not send a response"))?
 901                        }
 902                    }
 903                    Err(error) => {
 904                        let proto_err = match &error {
 905                            Error::Internal(err) => err.to_proto(),
 906                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 907                        };
 908                        peer.respond_with_error(receipt, proto_err)?;
 909                        Err(error)
 910                    }
 911                }
 912            }
 913        })
 914    }
 915
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a `handle connection` span and
    /// runs these steps:
    /// 1. Returns immediately if the server's teardown flag is already set.
    /// 2. Registers `connection` with the peer and builds the per-connection
    ///    [`Session`]. This includes an HTTP client and, when
    ///    `supermaven_admin_api_key` is configured, a Supermaven admin client.
    /// 3. Sends the initial client update. This also reports the new
    ///    `ConnectionId` through `send_connection_id` if one was supplied.
    /// 4. Dispatches incoming messages to the registered handlers. This
    ///    continues until the connection closes, I/O fails, or teardown is
    ///    signalled.
    /// 5. On close or I/O failure (but not on teardown), drops any pending
    ///    foreground handlers and runs `connection_lost`.
    ///
    /// NOTE(review): the early `return`s after `add_connection` skip
    /// `connection_lost`. These are HTTP client construction failure and
    /// initial-update failure. On those paths the connection appears to stay
    /// registered with `this.peer`, and possibly in the connection pool if
    /// `send_initial_client_update` got far enough. Confirm whether cleanup is
    /// needed.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Watches the server-wide teardown flag. It is checked once on entry and
        // awaited inside the dispatch loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers at 256. A permit is acquired *before* the
            // next message is read, and it is released only when that message's
            // handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Rebuilt every iteration. If another branch wins the select below,
                // this in-progress future (and any permit it holds) is dropped.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are checked in order: teardown, connection I/O,
                // completion of a foreground handler, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before storing the future. The future
                                // re-enters it through `.instrument(span)` whenever it is polled.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor.
                                // Foreground handlers are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1069
1070    async fn send_initial_client_update(
1071        &self,
1072        connection_id: ConnectionId,
1073        principal: &Principal,
1074        zed_version: ZedVersion,
1075        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1076        session: &Session,
1077    ) -> Result<()> {
1078        self.peer.send(
1079            connection_id,
1080            proto::Hello {
1081                peer_id: Some(connection_id.into()),
1082            },
1083        )?;
1084        tracing::info!("sent hello message");
1085        if let Some(send_connection_id) = send_connection_id.take() {
1086            let _ = send_connection_id.send(connection_id);
1087        }
1088
1089        match principal {
1090            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1091                if !user.connected_once {
1092                    self.peer.send(connection_id, proto::ShowContacts {})?;
1093                    self.app_state
1094                        .db
1095                        .set_user_connected_once(user.id, true)
1096                        .await?;
1097                }
1098
1099                update_user_plan(user.id, session).await?;
1100
1101                let (contacts, dev_server_projects) = future::try_join(
1102                    self.app_state.db.get_contacts(user.id),
1103                    self.app_state.db.dev_server_projects_update(user.id),
1104                )
1105                .await?;
1106
1107                {
1108                    let mut pool = self.connection_pool.lock();
1109                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1110                    self.peer.send(
1111                        connection_id,
1112                        build_initial_contacts_update(contacts, &pool),
1113                    )?;
1114                }
1115
1116                if should_auto_subscribe_to_channels(zed_version) {
1117                    subscribe_user_to_channels(user.id, session).await?;
1118                }
1119
1120                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1121
1122                if let Some(incoming_call) =
1123                    self.app_state.db.incoming_call_for_user(user.id).await?
1124                {
1125                    self.peer.send(connection_id, incoming_call)?;
1126                }
1127
1128                update_user_contacts(user.id, &session).await?;
1129            }
1130            Principal::DevServer(dev_server) => {
1131                {
1132                    let mut pool = self.connection_pool.lock();
1133                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1134                    {
1135                        self.peer.send(
1136                            stale_connection_id,
1137                            proto::ShutdownDevServer {
1138                                reason: Some(
1139                                    "another dev server connected with the same token".to_string(),
1140                                ),
1141                            },
1142                        )?;
1143                        pool.remove_connection(stale_connection_id)?;
1144                    };
1145                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1146                }
1147
1148                let projects = self
1149                    .app_state
1150                    .db
1151                    .get_projects_for_dev_server(dev_server.id)
1152                    .await?;
1153                self.peer
1154                    .send(connection_id, proto::DevServerInstructions { projects })?;
1155
1156                let status = self
1157                    .app_state
1158                    .db
1159                    .dev_server_projects_update(dev_server.user_id)
1160                    .await?;
1161                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1162            }
1163        }
1164
1165        Ok(())
1166    }
1167
1168    pub async fn invite_code_redeemed(
1169        self: &Arc<Self>,
1170        inviter_id: UserId,
1171        invitee_id: UserId,
1172    ) -> Result<()> {
1173        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1174            if let Some(code) = &user.invite_code {
1175                let pool = self.connection_pool.lock();
1176                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1177                for connection_id in pool.user_connection_ids(inviter_id) {
1178                    self.peer.send(
1179                        connection_id,
1180                        proto::UpdateContacts {
1181                            contacts: vec![invitee_contact.clone()],
1182                            ..Default::default()
1183                        },
1184                    )?;
1185                    self.peer.send(
1186                        connection_id,
1187                        proto::UpdateInviteInfo {
1188                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1189                            count: user.invite_count as u32,
1190                        },
1191                    )?;
1192                }
1193            }
1194        }
1195        Ok(())
1196    }
1197
1198    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1199        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1200            if let Some(invite_code) = &user.invite_code {
1201                let pool = self.connection_pool.lock();
1202                for connection_id in pool.user_connection_ids(user_id) {
1203                    self.peer.send(
1204                        connection_id,
1205                        proto::UpdateInviteInfo {
1206                            url: format!(
1207                                "{}{}",
1208                                self.app_state.config.invite_link_prefix, invite_code
1209                            ),
1210                            count: user.invite_count as u32,
1211                        },
1212                    )?;
1213                }
1214            }
1215        }
1216        Ok(())
1217    }
1218
1219    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1220        ServerSnapshot {
1221            connection_pool: ConnectionPoolGuard {
1222                guard: self.connection_pool.lock(),
1223                _not_send: PhantomData,
1224            },
1225            peer: &self.peer,
1226        }
1227    }
1228}
1229
1230impl<'a> Deref for ConnectionPoolGuard<'a> {
1231    type Target = ConnectionPool;
1232
1233    fn deref(&self) -> &Self::Target {
1234        &self.guard
1235    }
1236}
1237
1238impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1239    fn deref_mut(&mut self) -> &mut Self::Target {
1240        &mut self.guard
1241    }
1242}
1243
1244impl<'a> Drop for ConnectionPoolGuard<'a> {
1245    fn drop(&mut self) {
1246        #[cfg(test)]
1247        self.check_invariants();
1248    }
1249}
1250
1251fn broadcast<F>(
1252    sender_id: Option<ConnectionId>,
1253    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1254    mut f: F,
1255) where
1256    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1257{
1258    for receiver_id in receiver_ids {
1259        if Some(receiver_id) != sender_id {
1260            if let Err(error) = f(receiver_id) {
1261                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1262            }
1263        }
1264    }
1265}
1266
1267pub struct ProtocolVersion(u32);
1268
1269impl Header for ProtocolVersion {
1270    fn name() -> &'static HeaderName {
1271        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1272        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1273    }
1274
1275    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1276    where
1277        Self: Sized,
1278        I: Iterator<Item = &'i axum::http::HeaderValue>,
1279    {
1280        let version = values
1281            .next()
1282            .ok_or_else(axum::headers::Error::invalid)?
1283            .to_str()
1284            .map_err(|_| axum::headers::Error::invalid())?
1285            .parse()
1286            .map_err(|_| axum::headers::Error::invalid())?;
1287        Ok(Self(version))
1288    }
1289
1290    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1291        values.extend([self.0.to_string().parse().unwrap()]);
1292    }
1293}
1294
1295pub struct AppVersionHeader(SemanticVersion);
1296impl Header for AppVersionHeader {
1297    fn name() -> &'static HeaderName {
1298        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1299        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1300    }
1301
1302    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1303    where
1304        Self: Sized,
1305        I: Iterator<Item = &'i axum::http::HeaderValue>,
1306    {
1307        let version = values
1308            .next()
1309            .ok_or_else(axum::headers::Error::invalid)?
1310            .to_str()
1311            .map_err(|_| axum::headers::Error::invalid())?
1312            .parse()
1313            .map_err(|_| axum::headers::Error::invalid())?;
1314        Ok(Self(version))
1315    }
1316
1317    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1318        values.extend([self.0.to_string().parse().unwrap()]);
1319    }
1320}
1321
1322pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1323    Router::new()
1324        .route("/rpc", get(handle_websocket_request))
1325        .layer(
1326            ServiceBuilder::new()
1327                .layer(Extension(server.app_state.clone()))
1328                .layer(middleware::from_fn(auth::validate_header)),
1329        )
1330        .route("/metrics", get(handle_metrics))
1331        .layer(Extension(server))
1332}
1333
1334pub async fn handle_websocket_request(
1335    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1336    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1337    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1338    Extension(server): Extension<Arc<Server>>,
1339    Extension(principal): Extension<Principal>,
1340    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1341    ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343    if protocol_version != rpc::PROTOCOL_VERSION {
1344        return (
1345            StatusCode::UPGRADE_REQUIRED,
1346            "client must be upgraded".to_string(),
1347        )
1348            .into_response();
1349    }
1350
1351    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352        return (
1353            StatusCode::UPGRADE_REQUIRED,
1354            "no version header found".to_string(),
1355        )
1356            .into_response();
1357    };
1358
1359    if !version.can_collaborate() {
1360        return (
1361            StatusCode::UPGRADE_REQUIRED,
1362            "client must be upgraded".to_string(),
1363        )
1364            .into_response();
1365    }
1366
1367    let socket_address = socket_address.to_string();
1368    ws.on_upgrade(move |socket| {
1369        let socket = socket
1370            .map_ok(to_tungstenite_message)
1371            .err_into()
1372            .with(|message| async move { to_axum_message(message) });
1373        let connection = Connection::new(Box::pin(socket));
1374        async move {
1375            server
1376                .handle_connection(
1377                    connection,
1378                    socket_address,
1379                    principal,
1380                    version,
1381                    country_code_header.map(|header| header.to_string()),
1382                    None,
1383                    Executor::Production,
1384                )
1385                .await;
1386        }
1387    })
1388}
1389
1390pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1391    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1392    let connections_metric = CONNECTIONS_METRIC
1393        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1394
1395    let connections = server
1396        .connection_pool
1397        .lock()
1398        .connections()
1399        .filter(|connection| !connection.admin)
1400        .count();
1401    connections_metric.set(connections as _);
1402
1403    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1405        register_int_gauge!(
1406            "shared_projects",
1407            "number of open projects with one or more guests"
1408        )
1409        .unwrap()
1410    });
1411
1412    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1413    shared_projects_metric.set(shared_projects as _);
1414
1415    let encoder = prometheus::TextEncoder::new();
1416    let metric_families = prometheus::gather();
1417    let encoded_metrics = encoder
1418        .encode_to_string(&metric_families)
1419        .map_err(|err| anyhow!("{}", err))?;
1420    Ok(encoded_metrics)
1421}
1422
1423#[instrument(err, skip(executor))]
1424async fn connection_lost(
1425    session: Session,
1426    mut teardown: watch::Receiver<bool>,
1427    executor: Executor,
1428) -> Result<()> {
1429    session.peer.disconnect(session.connection_id);
1430    session
1431        .connection_pool()
1432        .await
1433        .remove_connection(session.connection_id)?;
1434
1435    session
1436        .db()
1437        .await
1438        .connection_lost(session.connection_id)
1439        .await
1440        .trace_err();
1441
1442    futures::select_biased! {
1443        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1444            match &session.principal {
1445                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1446                    let session = session.for_user().unwrap();
1447
1448                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1449                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1450                    leave_channel_buffers_for_session(&session)
1451                        .await
1452                        .trace_err();
1453
1454                    if !session
1455                        .connection_pool()
1456                        .await
1457                        .is_user_online(session.user_id())
1458                    {
1459                        let db = session.db().await;
1460                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1461                            room_updated(&room, &session.peer);
1462                        }
1463                    }
1464
1465                    update_user_contacts(session.user_id(), &session).await?;
1466                },
1467            Principal::DevServer(_) => {
1468                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1469            },
1470        }
1471        },
1472        _ = teardown.changed().fuse() => {}
1473    }
1474
1475    Ok(())
1476}
1477
1478/// Acknowledges a ping from a client, used to keep the connection alive.
1479async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1480    response.send(proto::Ack {})?;
1481    Ok(())
1482}
1483
1484/// Creates a new room for calling (outside of channels)
1485async fn create_room(
1486    _request: proto::CreateRoom,
1487    response: Response<proto::CreateRoom>,
1488    session: UserSession,
1489) -> Result<()> {
1490    let live_kit_room = nanoid::nanoid!(30);
1491
1492    let live_kit_connection_info = util::maybe!(async {
1493        let live_kit = session.app_state.live_kit_client.as_ref();
1494        let live_kit = live_kit?;
1495        let user_id = session.user_id().to_string();
1496
1497        let token = live_kit
1498            .room_token(&live_kit_room, &user_id.to_string())
1499            .trace_err()?;
1500
1501        Some(proto::LiveKitConnectionInfo {
1502            server_url: live_kit.url().into(),
1503            token,
1504            can_publish: true,
1505        })
1506    })
1507    .await;
1508
1509    let room = session
1510        .db()
1511        .await
1512        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1513        .await?;
1514
1515    response.send(proto::CreateRoomResponse {
1516        room: Some(room.clone()),
1517        live_kit_connection_info,
1518    })?;
1519
1520    update_user_contacts(session.user_id(), &session).await?;
1521    Ok(())
1522}
1523
1524/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1525async fn join_room(
1526    request: proto::JoinRoom,
1527    response: Response<proto::JoinRoom>,
1528    session: UserSession,
1529) -> Result<()> {
1530    let room_id = RoomId::from_proto(request.id);
1531
1532    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1533
1534    if let Some(channel_id) = channel_id {
1535        return join_channel_internal(channel_id, Box::new(response), session).await;
1536    }
1537
1538    let joined_room = {
1539        let room = session
1540            .db()
1541            .await
1542            .join_room(room_id, session.user_id(), session.connection_id)
1543            .await?;
1544        room_updated(&room.room, &session.peer);
1545        room.into_inner()
1546    };
1547
1548    for connection_id in session
1549        .connection_pool()
1550        .await
1551        .user_connection_ids(session.user_id())
1552    {
1553        session
1554            .peer
1555            .send(
1556                connection_id,
1557                proto::CallCanceled {
1558                    room_id: room_id.to_proto(),
1559                },
1560            )
1561            .trace_err();
1562    }
1563
1564    let live_kit_connection_info =
1565        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1566            if let Some(token) = live_kit
1567                .room_token(
1568                    &joined_room.room.live_kit_room,
1569                    &session.user_id().to_string(),
1570                )
1571                .trace_err()
1572            {
1573                Some(proto::LiveKitConnectionInfo {
1574                    server_url: live_kit.url().into(),
1575                    token,
1576                    can_publish: true,
1577                })
1578            } else {
1579                None
1580            }
1581        } else {
1582            None
1583        };
1584
1585    response.send(proto::JoinRoomResponse {
1586        room: Some(joined_room.room),
1587        channel_id: None,
1588        live_kit_connection_info,
1589    })?;
1590
1591    update_user_contacts(session.user_id(), &session).await?;
1592    Ok(())
1593}
1594
/// Reconnects a user to a room after connection errors.
///
/// The database re-associates the user with the room under the session's
/// current connection and reports the room, its `reshared_projects`, its
/// `rejoined_projects`, and (optionally) the room's channel. The handler then:
/// 1. replies with the room and a summary of both project lists;
/// 2. broadcasts the room update;
/// 3. for each reshared project, tells every collaborator that the user's peer
///    id changed and forwards an `UpdateProject` carrying the project's
///    worktrees;
/// 4. streams rejoined-project state via `notify_rejoined_projects`;
/// 5. after the database guard is released, broadcasts the channel update (if
///    any) and refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the scope below so the database guard is dropped before
    // the channel/contacts updates run.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at the user's new connection; failures
            // are logged, not returned.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators, with the
            // rejoining connection passed to `broadcast` as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1686
1687fn notify_rejoined_projects(
1688    rejoined_projects: &mut Vec<RejoinedProject>,
1689    session: &UserSession,
1690) -> Result<()> {
1691    for project in rejoined_projects.iter() {
1692        for collaborator in &project.collaborators {
1693            session
1694                .peer
1695                .send(
1696                    collaborator.connection_id,
1697                    proto::UpdateProjectCollaborator {
1698                        project_id: project.id.to_proto(),
1699                        old_peer_id: Some(project.old_connection_id.into()),
1700                        new_peer_id: Some(session.connection_id.into()),
1701                    },
1702                )
1703                .trace_err();
1704        }
1705    }
1706
1707    for project in rejoined_projects {
1708        for worktree in mem::take(&mut project.worktrees) {
1709            #[cfg(any(test, feature = "test-support"))]
1710            const MAX_CHUNK_SIZE: usize = 2;
1711            #[cfg(not(any(test, feature = "test-support")))]
1712            const MAX_CHUNK_SIZE: usize = 256;
1713
1714            // Stream this worktree's entries.
1715            let message = proto::UpdateWorktree {
1716                project_id: project.id.to_proto(),
1717                worktree_id: worktree.id,
1718                abs_path: worktree.abs_path.clone(),
1719                root_name: worktree.root_name,
1720                updated_entries: worktree.updated_entries,
1721                removed_entries: worktree.removed_entries,
1722                scan_id: worktree.scan_id,
1723                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1724                updated_repositories: worktree.updated_repositories,
1725                removed_repositories: worktree.removed_repositories,
1726            };
1727            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1728                session.peer.send(session.connection_id, update.clone())?;
1729            }
1730
1731            // Stream this worktree's diagnostics.
1732            for summary in worktree.diagnostic_summaries {
1733                session.peer.send(
1734                    session.connection_id,
1735                    proto::UpdateDiagnosticSummary {
1736                        project_id: project.id.to_proto(),
1737                        worktree_id: worktree.id,
1738                        summary: Some(summary),
1739                    },
1740                )?;
1741            }
1742
1743            for settings_file in worktree.settings_files {
1744                session.peer.send(
1745                    session.connection_id,
1746                    proto::UpdateWorktreeSettings {
1747                        project_id: project.id.to_proto(),
1748                        worktree_id: worktree.id,
1749                        path: settings_file.path,
1750                        content: Some(settings_file.content),
1751                    },
1752                )?;
1753            }
1754        }
1755
1756        for language_server in &project.language_servers {
1757            session.peer.send(
1758                session.connection_id,
1759                proto::UpdateLanguageServer {
1760                    project_id: project.id.to_proto(),
1761                    language_server_id: language_server.id,
1762                    variant: Some(
1763                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1764                            proto::LspDiskBasedDiagnosticsUpdated {},
1765                        ),
1766                    ),
1767                },
1768            )?;
1769        }
1770    }
1771    Ok(())
1772}
1773
1774/// leave room disconnects from the room.
1775async fn leave_room(
1776    _: proto::LeaveRoom,
1777    response: Response<proto::LeaveRoom>,
1778    session: UserSession,
1779) -> Result<()> {
1780    leave_room_for_session(&session, session.connection_id).await?;
1781    response.send(proto::Ack {})?;
1782    Ok(())
1783}
1784
1785/// Updates the permissions of someone else in the room.
1786async fn set_room_participant_role(
1787    request: proto::SetRoomParticipantRole,
1788    response: Response<proto::SetRoomParticipantRole>,
1789    session: UserSession,
1790) -> Result<()> {
1791    let user_id = UserId::from_proto(request.user_id);
1792    let role = ChannelRole::from(request.role());
1793
1794    let (live_kit_room, can_publish) = {
1795        let room = session
1796            .db()
1797            .await
1798            .set_room_participant_role(
1799                session.user_id(),
1800                RoomId::from_proto(request.room_id),
1801                user_id,
1802                role,
1803            )
1804            .await?;
1805
1806        let live_kit_room = room.live_kit_room.clone();
1807        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1808        room_updated(&room, &session.peer);
1809        (live_kit_room, can_publish)
1810    };
1811
1812    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1813        live_kit
1814            .update_participant(
1815                live_kit_room.clone(),
1816                request.user_id.to_string(),
1817                live_kit_server::proto::ParticipantPermission {
1818                    can_subscribe: true,
1819                    can_publish,
1820                    can_publish_data: can_publish,
1821                    hidden: false,
1822                    recorder: false,
1823                },
1824            )
1825            .await
1826            .trace_err();
1827    }
1828
1829    response.send(proto::Ack {})?;
1830    Ok(())
1831}
1832
1833/// Call someone else into the current room
1834async fn call(
1835    request: proto::Call,
1836    response: Response<proto::Call>,
1837    session: UserSession,
1838) -> Result<()> {
1839    let room_id = RoomId::from_proto(request.room_id);
1840    let calling_user_id = session.user_id();
1841    let calling_connection_id = session.connection_id;
1842    let called_user_id = UserId::from_proto(request.called_user_id);
1843    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1844    if !session
1845        .db()
1846        .await
1847        .has_contact(calling_user_id, called_user_id)
1848        .await?
1849    {
1850        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1851    }
1852
1853    let incoming_call = {
1854        let (room, incoming_call) = &mut *session
1855            .db()
1856            .await
1857            .call(
1858                room_id,
1859                calling_user_id,
1860                calling_connection_id,
1861                called_user_id,
1862                initial_project_id,
1863            )
1864            .await?;
1865        room_updated(&room, &session.peer);
1866        mem::take(incoming_call)
1867    };
1868    update_user_contacts(called_user_id, &session).await?;
1869
1870    let mut calls = session
1871        .connection_pool()
1872        .await
1873        .user_connection_ids(called_user_id)
1874        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1875        .collect::<FuturesUnordered<_>>();
1876
1877    while let Some(call_response) = calls.next().await {
1878        match call_response.as_ref() {
1879            Ok(_) => {
1880                response.send(proto::Ack {})?;
1881                return Ok(());
1882            }
1883            Err(_) => {
1884                call_response.trace_err();
1885            }
1886        }
1887    }
1888
1889    {
1890        let room = session
1891            .db()
1892            .await
1893            .call_failed(room_id, called_user_id)
1894            .await?;
1895        room_updated(&room, &session.peer);
1896    }
1897    update_user_contacts(called_user_id, &session).await?;
1898
1899    Err(anyhow!("failed to ring user"))?
1900}
1901
1902/// Cancel an outgoing call.
1903async fn cancel_call(
1904    request: proto::CancelCall,
1905    response: Response<proto::CancelCall>,
1906    session: UserSession,
1907) -> Result<()> {
1908    let called_user_id = UserId::from_proto(request.called_user_id);
1909    let room_id = RoomId::from_proto(request.room_id);
1910    {
1911        let room = session
1912            .db()
1913            .await
1914            .cancel_call(room_id, session.connection_id, called_user_id)
1915            .await?;
1916        room_updated(&room, &session.peer);
1917    }
1918
1919    for connection_id in session
1920        .connection_pool()
1921        .await
1922        .user_connection_ids(called_user_id)
1923    {
1924        session
1925            .peer
1926            .send(
1927                connection_id,
1928                proto::CallCanceled {
1929                    room_id: room_id.to_proto(),
1930                },
1931            )
1932            .trace_err();
1933    }
1934    response.send(proto::Ack {})?;
1935
1936    update_user_contacts(called_user_id, &session).await?;
1937    Ok(())
1938}
1939
1940/// Decline an incoming call.
1941async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1942    let room_id = RoomId::from_proto(message.room_id);
1943    {
1944        let room = session
1945            .db()
1946            .await
1947            .decline_call(Some(room_id), session.user_id())
1948            .await?
1949            .ok_or_else(|| anyhow!("failed to decline call"))?;
1950        room_updated(&room, &session.peer);
1951    }
1952
1953    for connection_id in session
1954        .connection_pool()
1955        .await
1956        .user_connection_ids(session.user_id())
1957    {
1958        session
1959            .peer
1960            .send(
1961                connection_id,
1962                proto::CallCanceled {
1963                    room_id: room_id.to_proto(),
1964                },
1965            )
1966            .trace_err();
1967    }
1968    update_user_contacts(session.user_id(), &session).await?;
1969    Ok(())
1970}
1971
1972/// Updates other participants in the room with your current location.
1973async fn update_participant_location(
1974    request: proto::UpdateParticipantLocation,
1975    response: Response<proto::UpdateParticipantLocation>,
1976    session: UserSession,
1977) -> Result<()> {
1978    let room_id = RoomId::from_proto(request.room_id);
1979    let location = request
1980        .location
1981        .ok_or_else(|| anyhow!("invalid location"))?;
1982
1983    let db = session.db().await;
1984    let room = db
1985        .update_room_participant_location(room_id, session.connection_id, location)
1986        .await?;
1987
1988    room_updated(&room, &session.peer);
1989    response.send(proto::Ack {})?;
1990    Ok(())
1991}
1992
1993/// Share a project into the room.
1994async fn share_project(
1995    request: proto::ShareProject,
1996    response: Response<proto::ShareProject>,
1997    session: UserSession,
1998) -> Result<()> {
1999    let (project_id, room) = &*session
2000        .db()
2001        .await
2002        .share_project(
2003            RoomId::from_proto(request.room_id),
2004            session.connection_id,
2005            &request.worktrees,
2006            request
2007                .dev_server_project_id
2008                .map(|id| DevServerProjectId::from_proto(id)),
2009        )
2010        .await?;
2011    response.send(proto::ShareProjectResponse {
2012        project_id: project_id.to_proto(),
2013    })?;
2014    room_updated(&room, &session.peer);
2015
2016    Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021    let project_id = ProjectId::from_proto(message.project_id);
2022    unshare_project_internal(
2023        project_id,
2024        session.connection_id,
2025        session.user_id(),
2026        &session,
2027    )
2028    .await
2029}
2030
2031async fn unshare_project_internal(
2032    project_id: ProjectId,
2033    connection_id: ConnectionId,
2034    user_id: Option<UserId>,
2035    session: &Session,
2036) -> Result<()> {
2037    let delete = {
2038        let room_guard = session
2039            .db()
2040            .await
2041            .unshare_project(project_id, connection_id, user_id)
2042            .await?;
2043
2044        let (delete, room, guest_connection_ids) = &*room_guard;
2045
2046        let message = proto::UnshareProject {
2047            project_id: project_id.to_proto(),
2048        };
2049
2050        broadcast(
2051            Some(connection_id),
2052            guest_connection_ids.iter().copied(),
2053            |conn_id| session.peer.send(conn_id, message.clone()),
2054        );
2055        if let Some(room) = room {
2056            room_updated(room, &session.peer);
2057        }
2058
2059        *delete
2060    };
2061
2062    if delete {
2063        let db = session.db().await;
2064        db.delete_project(project_id).await?;
2065    }
2066
2067    Ok(())
2068}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072    request: proto::ShareDevServerProject,
2073    response: Response<proto::ShareDevServerProject>,
2074    session: DevServerSession,
2075) -> Result<()> {
2076    let (dev_server_project, user_id, status) = session
2077        .db()
2078        .await
2079        .share_dev_server_project(
2080            DevServerProjectId::from_proto(request.dev_server_project_id),
2081            session.dev_server_id(),
2082            session.connection_id,
2083            &request.worktrees,
2084        )
2085        .await?;
2086    let Some(project_id) = dev_server_project.project_id else {
2087        return Err(anyhow!("failed to share remote project"))?;
2088    };
2089
2090    send_dev_server_projects_update(user_id, status, &session).await;
2091
2092    response.send(proto::ShareProjectResponse { project_id })?;
2093
2094    Ok(())
2095}
2096
2097/// Join someone elses shared project.
2098async fn join_project(
2099    request: proto::JoinProject,
2100    response: Response<proto::JoinProject>,
2101    session: UserSession,
2102) -> Result<()> {
2103    let project_id = ProjectId::from_proto(request.project_id);
2104
2105    tracing::info!(%project_id, "join project");
2106
2107    let db = session.db().await;
2108    let (project, replica_id) = &mut *db
2109        .join_project(project_id, session.connection_id, session.user_id())
2110        .await?;
2111    drop(db);
2112    tracing::info!(%project_id, "join remote project");
2113    join_project_internal(response, session, project, replica_id)
2114}
2115
2116trait JoinProjectInternalResponse {
2117    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2118}
2119impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2120    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2121        Response::<proto::JoinProject>::send(self, result)
2122    }
2123}
2124impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2125    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2126        Response::<proto::JoinHostedProject>::send(self, result)
2127    }
2128}
2129
2130fn join_project_internal(
2131    response: impl JoinProjectInternalResponse,
2132    session: UserSession,
2133    project: &mut Project,
2134    replica_id: &ReplicaId,
2135) -> Result<()> {
2136    let collaborators = project
2137        .collaborators
2138        .iter()
2139        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140        .map(|collaborator| collaborator.to_proto())
2141        .collect::<Vec<_>>();
2142    let project_id = project.id;
2143    let guest_user_id = session.user_id();
2144
2145    let worktrees = project
2146        .worktrees
2147        .iter()
2148        .map(|(id, worktree)| proto::WorktreeMetadata {
2149            id: *id,
2150            root_name: worktree.root_name.clone(),
2151            visible: worktree.visible,
2152            abs_path: worktree.abs_path.clone(),
2153        })
2154        .collect::<Vec<_>>();
2155
2156    let add_project_collaborator = proto::AddProjectCollaborator {
2157        project_id: project_id.to_proto(),
2158        collaborator: Some(proto::Collaborator {
2159            peer_id: Some(session.connection_id.into()),
2160            replica_id: replica_id.0 as u32,
2161            user_id: guest_user_id.to_proto(),
2162        }),
2163    };
2164
2165    for collaborator in &collaborators {
2166        session
2167            .peer
2168            .send(
2169                collaborator.peer_id.unwrap().into(),
2170                add_project_collaborator.clone(),
2171            )
2172            .trace_err();
2173    }
2174
2175    // First, we send the metadata associated with each worktree.
2176    response.send(proto::JoinProjectResponse {
2177        project_id: project.id.0 as u64,
2178        worktrees: worktrees.clone(),
2179        replica_id: replica_id.0 as u32,
2180        collaborators: collaborators.clone(),
2181        language_servers: project.language_servers.clone(),
2182        role: project.role.into(),
2183        dev_server_project_id: project
2184            .dev_server_project_id
2185            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186    })?;
2187
2188    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189        #[cfg(any(test, feature = "test-support"))]
2190        const MAX_CHUNK_SIZE: usize = 2;
2191        #[cfg(not(any(test, feature = "test-support")))]
2192        const MAX_CHUNK_SIZE: usize = 256;
2193
2194        // Stream this worktree's entries.
2195        let message = proto::UpdateWorktree {
2196            project_id: project_id.to_proto(),
2197            worktree_id,
2198            abs_path: worktree.abs_path.clone(),
2199            root_name: worktree.root_name,
2200            updated_entries: worktree.entries,
2201            removed_entries: Default::default(),
2202            scan_id: worktree.scan_id,
2203            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204            updated_repositories: worktree.repository_entries.into_values().collect(),
2205            removed_repositories: Default::default(),
2206        };
2207        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208            session.peer.send(session.connection_id, update.clone())?;
2209        }
2210
2211        // Stream this worktree's diagnostics.
2212        for summary in worktree.diagnostic_summaries {
2213            session.peer.send(
2214                session.connection_id,
2215                proto::UpdateDiagnosticSummary {
2216                    project_id: project_id.to_proto(),
2217                    worktree_id: worktree.id,
2218                    summary: Some(summary),
2219                },
2220            )?;
2221        }
2222
2223        for settings_file in worktree.settings_files {
2224            session.peer.send(
2225                session.connection_id,
2226                proto::UpdateWorktreeSettings {
2227                    project_id: project_id.to_proto(),
2228                    worktree_id: worktree.id,
2229                    path: settings_file.path,
2230                    content: Some(settings_file.content),
2231                },
2232            )?;
2233        }
2234    }
2235
2236    for language_server in &project.language_servers {
2237        session.peer.send(
2238            session.connection_id,
2239            proto::UpdateLanguageServer {
2240                project_id: project_id.to_proto(),
2241                language_server_id: language_server.id,
2242                variant: Some(
2243                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2244                        proto::LspDiskBasedDiagnosticsUpdated {},
2245                    ),
2246                ),
2247            },
2248        )?;
2249    }
2250
2251    Ok(())
2252}
2253
2254/// Leave someone elses shared project.
2255async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2256    let sender_id = session.connection_id;
2257    let project_id = ProjectId::from_proto(request.project_id);
2258    let db = session.db().await;
2259    if db.is_hosted_project(project_id).await? {
2260        let project = db.leave_hosted_project(project_id, sender_id).await?;
2261        project_left(&project, &session);
2262        return Ok(());
2263    }
2264
2265    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2266    tracing::info!(
2267        %project_id,
2268        "leave project"
2269    );
2270
2271    project_left(&project, &session);
2272    if let Some(room) = room {
2273        room_updated(&room, &session.peer);
2274    }
2275
2276    Ok(())
2277}
2278
2279async fn join_hosted_project(
2280    request: proto::JoinHostedProject,
2281    response: Response<proto::JoinHostedProject>,
2282    session: UserSession,
2283) -> Result<()> {
2284    let (mut project, replica_id) = session
2285        .db()
2286        .await
2287        .join_hosted_project(
2288            ProjectId(request.project_id as i32),
2289            session.user_id(),
2290            session.connection_id,
2291        )
2292        .await?;
2293
2294    join_project_internal(response, session, &mut project, &replica_id)
2295}
2296
2297async fn list_remote_directory(
2298    request: proto::ListRemoteDirectory,
2299    response: Response<proto::ListRemoteDirectory>,
2300    session: UserSession,
2301) -> Result<()> {
2302    let dev_server_id = DevServerId(request.dev_server_id as i32);
2303    let dev_server_connection_id = session
2304        .connection_pool()
2305        .await
2306        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2307
2308    session
2309        .db()
2310        .await
2311        .get_dev_server_for_user(dev_server_id, session.user_id())
2312        .await?;
2313
2314    response.send(
2315        session
2316            .peer
2317            .forward_request(session.connection_id, dev_server_connection_id, request)
2318            .await?,
2319    )?;
2320    Ok(())
2321}
2322
2323async fn update_dev_server_project(
2324    request: proto::UpdateDevServerProject,
2325    response: Response<proto::UpdateDevServerProject>,
2326    session: UserSession,
2327) -> Result<()> {
2328    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2329
2330    let (dev_server_project, update) = session
2331        .db()
2332        .await
2333        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2334        .await?;
2335
2336    let projects = session
2337        .db()
2338        .await
2339        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2340        .await?;
2341
2342    let dev_server_connection_id = session
2343        .connection_pool()
2344        .await
2345        .dev_server_connection_id_supporting(
2346            dev_server_project.dev_server_id,
2347            ZedVersion::with_list_directory(),
2348        )?;
2349
2350    session.peer.send(
2351        dev_server_connection_id,
2352        proto::DevServerInstructions { projects },
2353    )?;
2354
2355    send_dev_server_projects_update(session.user_id(), update, &session).await;
2356
2357    response.send(proto::Ack {})
2358}
2359
2360async fn create_dev_server_project(
2361    request: proto::CreateDevServerProject,
2362    response: Response<proto::CreateDevServerProject>,
2363    session: UserSession,
2364) -> Result<()> {
2365    let dev_server_id = DevServerId(request.dev_server_id as i32);
2366    let dev_server_connection_id = session
2367        .connection_pool()
2368        .await
2369        .dev_server_connection_id(dev_server_id);
2370    let Some(dev_server_connection_id) = dev_server_connection_id else {
2371        Err(ErrorCode::DevServerOffline
2372            .message("Cannot create a remote project when the dev server is offline".to_string())
2373            .anyhow())?
2374    };
2375
2376    let path = request.path.clone();
2377    //Check that the path exists on the dev server
2378    session
2379        .peer
2380        .forward_request(
2381            session.connection_id,
2382            dev_server_connection_id,
2383            proto::ValidateDevServerProjectRequest { path: path.clone() },
2384        )
2385        .await?;
2386
2387    let (dev_server_project, update) = session
2388        .db()
2389        .await
2390        .create_dev_server_project(
2391            DevServerId(request.dev_server_id as i32),
2392            &request.path,
2393            session.user_id(),
2394        )
2395        .await?;
2396
2397    let projects = session
2398        .db()
2399        .await
2400        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2401        .await?;
2402
2403    session.peer.send(
2404        dev_server_connection_id,
2405        proto::DevServerInstructions { projects },
2406    )?;
2407
2408    send_dev_server_projects_update(session.user_id(), update, &session).await;
2409
2410    response.send(proto::CreateDevServerProjectResponse {
2411        dev_server_project: Some(dev_server_project.to_proto(None)),
2412    })?;
2413    Ok(())
2414}
2415
2416async fn create_dev_server(
2417    request: proto::CreateDevServer,
2418    response: Response<proto::CreateDevServer>,
2419    session: UserSession,
2420) -> Result<()> {
2421    let access_token = auth::random_token();
2422    let hashed_access_token = auth::hash_access_token(&access_token);
2423
2424    if request.name.is_empty() {
2425        return Err(proto::ErrorCode::Forbidden
2426            .message("Dev server name cannot be empty".to_string())
2427            .anyhow())?;
2428    }
2429
2430    let (dev_server, status) = session
2431        .db()
2432        .await
2433        .create_dev_server(
2434            &request.name,
2435            request.ssh_connection_string.as_deref(),
2436            &hashed_access_token,
2437            session.user_id(),
2438        )
2439        .await?;
2440
2441    send_dev_server_projects_update(session.user_id(), status, &session).await;
2442
2443    response.send(proto::CreateDevServerResponse {
2444        dev_server_id: dev_server.id.0 as u64,
2445        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2446        name: request.name,
2447    })?;
2448    Ok(())
2449}
2450
2451async fn regenerate_dev_server_token(
2452    request: proto::RegenerateDevServerToken,
2453    response: Response<proto::RegenerateDevServerToken>,
2454    session: UserSession,
2455) -> Result<()> {
2456    let dev_server_id = DevServerId(request.dev_server_id as i32);
2457    let access_token = auth::random_token();
2458    let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460    let connection_id = session
2461        .connection_pool()
2462        .await
2463        .dev_server_connection_id(dev_server_id);
2464    if let Some(connection_id) = connection_id {
2465        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2466        session.peer.send(
2467            connection_id,
2468            proto::ShutdownDevServer {
2469                reason: Some("dev server token was regenerated".to_string()),
2470            },
2471        )?;
2472        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2473    }
2474
2475    let status = session
2476        .db()
2477        .await
2478        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2479        .await?;
2480
2481    send_dev_server_projects_update(session.user_id(), status, &session).await;
2482
2483    response.send(proto::RegenerateDevServerTokenResponse {
2484        dev_server_id: dev_server_id.to_proto(),
2485        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2486    })?;
2487    Ok(())
2488}
2489
2490async fn rename_dev_server(
2491    request: proto::RenameDevServer,
2492    response: Response<proto::RenameDevServer>,
2493    session: UserSession,
2494) -> Result<()> {
2495    if request.name.trim().is_empty() {
2496        return Err(proto::ErrorCode::Forbidden
2497            .message("Dev server name cannot be empty".to_string())
2498            .anyhow())?;
2499    }
2500
2501    let dev_server_id = DevServerId(request.dev_server_id as i32);
2502    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2503    if dev_server.user_id != session.user_id() {
2504        return Err(anyhow!(ErrorCode::Forbidden))?;
2505    }
2506
2507    let status = session
2508        .db()
2509        .await
2510        .rename_dev_server(
2511            dev_server_id,
2512            &request.name,
2513            request.ssh_connection_string.as_deref(),
2514            session.user_id(),
2515        )
2516        .await?;
2517
2518    send_dev_server_projects_update(session.user_id(), status, &session).await;
2519
2520    response.send(proto::Ack {})?;
2521    Ok(())
2522}
2523
2524async fn delete_dev_server(
2525    request: proto::DeleteDevServer,
2526    response: Response<proto::DeleteDevServer>,
2527    session: UserSession,
2528) -> Result<()> {
2529    let dev_server_id = DevServerId(request.dev_server_id as i32);
2530    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2531    if dev_server.user_id != session.user_id() {
2532        return Err(anyhow!(ErrorCode::Forbidden))?;
2533    }
2534
2535    let connection_id = session
2536        .connection_pool()
2537        .await
2538        .dev_server_connection_id(dev_server_id);
2539    if let Some(connection_id) = connection_id {
2540        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2541        session.peer.send(
2542            connection_id,
2543            proto::ShutdownDevServer {
2544                reason: Some("dev server was deleted".to_string()),
2545            },
2546        )?;
2547        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2548    }
2549
2550    let status = session
2551        .db()
2552        .await
2553        .delete_dev_server(dev_server_id, session.user_id())
2554        .await?;
2555
2556    send_dev_server_projects_update(session.user_id(), status, &session).await;
2557
2558    response.send(proto::Ack {})?;
2559    Ok(())
2560}
2561
2562async fn delete_dev_server_project(
2563    request: proto::DeleteDevServerProject,
2564    response: Response<proto::DeleteDevServerProject>,
2565    session: UserSession,
2566) -> Result<()> {
2567    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2568    let dev_server_project = session
2569        .db()
2570        .await
2571        .get_dev_server_project(dev_server_project_id)
2572        .await?;
2573
2574    let dev_server = session
2575        .db()
2576        .await
2577        .get_dev_server(dev_server_project.dev_server_id)
2578        .await?;
2579    if dev_server.user_id != session.user_id() {
2580        return Err(anyhow!(ErrorCode::Forbidden))?;
2581    }
2582
2583    let dev_server_connection_id = session
2584        .connection_pool()
2585        .await
2586        .dev_server_connection_id(dev_server.id);
2587
2588    if let Some(dev_server_connection_id) = dev_server_connection_id {
2589        let project = session
2590            .db()
2591            .await
2592            .find_dev_server_project(dev_server_project_id)
2593            .await;
2594        if let Ok(project) = project {
2595            unshare_project_internal(
2596                project.id,
2597                dev_server_connection_id,
2598                Some(session.user_id()),
2599                &session,
2600            )
2601            .await?;
2602        }
2603    }
2604
2605    let (projects, status) = session
2606        .db()
2607        .await
2608        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2609        .await?;
2610
2611    if let Some(dev_server_connection_id) = dev_server_connection_id {
2612        session.peer.send(
2613            dev_server_connection_id,
2614            proto::DevServerInstructions { projects },
2615        )?;
2616    }
2617
2618    send_dev_server_projects_update(session.user_id(), status, &session).await;
2619
2620    response.send(proto::Ack {})?;
2621    Ok(())
2622}
2623
2624async fn rejoin_dev_server_projects(
2625    request: proto::RejoinRemoteProjects,
2626    response: Response<proto::RejoinRemoteProjects>,
2627    session: UserSession,
2628) -> Result<()> {
2629    let mut rejoined_projects = {
2630        let db = session.db().await;
2631        db.rejoin_dev_server_projects(
2632            &request.rejoined_projects,
2633            session.user_id(),
2634            session.0.connection_id,
2635        )
2636        .await?
2637    };
2638    response.send(proto::RejoinRemoteProjectsResponse {
2639        rejoined_projects: rejoined_projects
2640            .iter()
2641            .map(|project| project.to_proto())
2642            .collect(),
2643    })?;
2644    notify_rejoined_projects(&mut rejoined_projects, &session)
2645}
2646
2647async fn reconnect_dev_server(
2648    request: proto::ReconnectDevServer,
2649    response: Response<proto::ReconnectDevServer>,
2650    session: DevServerSession,
2651) -> Result<()> {
2652    let reshared_projects = {
2653        let db = session.db().await;
2654        db.reshare_dev_server_projects(
2655            &request.reshared_projects,
2656            session.dev_server_id(),
2657            session.0.connection_id,
2658        )
2659        .await?
2660    };
2661
2662    for project in &reshared_projects {
2663        for collaborator in &project.collaborators {
2664            session
2665                .peer
2666                .send(
2667                    collaborator.connection_id,
2668                    proto::UpdateProjectCollaborator {
2669                        project_id: project.id.to_proto(),
2670                        old_peer_id: Some(project.old_connection_id.into()),
2671                        new_peer_id: Some(session.connection_id.into()),
2672                    },
2673                )
2674                .trace_err();
2675        }
2676
2677        broadcast(
2678            Some(session.connection_id),
2679            project
2680                .collaborators
2681                .iter()
2682                .map(|collaborator| collaborator.connection_id),
2683            |connection_id| {
2684                session.peer.forward_send(
2685                    session.connection_id,
2686                    connection_id,
2687                    proto::UpdateProject {
2688                        project_id: project.id.to_proto(),
2689                        worktrees: project.worktrees.clone(),
2690                    },
2691                )
2692            },
2693        );
2694    }
2695
2696    response.send(proto::ReconnectDevServerResponse {
2697        reshared_projects: reshared_projects
2698            .iter()
2699            .map(|project| proto::ResharedProject {
2700                id: project.id.to_proto(),
2701                collaborators: project
2702                    .collaborators
2703                    .iter()
2704                    .map(|collaborator| collaborator.to_proto())
2705                    .collect(),
2706            })
2707            .collect(),
2708    })?;
2709
2710    Ok(())
2711}
2712
2713async fn shutdown_dev_server(
2714    _: proto::ShutdownDevServer,
2715    response: Response<proto::ShutdownDevServer>,
2716    session: DevServerSession,
2717) -> Result<()> {
2718    response.send(proto::Ack {})?;
2719    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2720    remove_dev_server_connection(session.dev_server_id(), &session).await
2721}
2722
2723async fn shutdown_dev_server_internal(
2724    dev_server_id: DevServerId,
2725    connection_id: ConnectionId,
2726    session: &Session,
2727) -> Result<()> {
2728    let (dev_server_projects, dev_server) = {
2729        let db = session.db().await;
2730        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2731        let dev_server = db.get_dev_server(dev_server_id).await?;
2732        (dev_server_projects, dev_server)
2733    };
2734
2735    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2736        unshare_project_internal(
2737            ProjectId::from_proto(project_id),
2738            connection_id,
2739            None,
2740            session,
2741        )
2742        .await?;
2743    }
2744
2745    session
2746        .connection_pool()
2747        .await
2748        .set_dev_server_offline(dev_server_id);
2749
2750    let status = session
2751        .db()
2752        .await
2753        .dev_server_projects_update(dev_server.user_id)
2754        .await?;
2755    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2756
2757    Ok(())
2758}
2759
2760async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2761    let dev_server_connection = session
2762        .connection_pool()
2763        .await
2764        .dev_server_connection_id(dev_server_id);
2765
2766    if let Some(dev_server_connection) = dev_server_connection {
2767        session
2768            .connection_pool()
2769            .await
2770            .remove_connection(dev_server_connection)?;
2771    }
2772    Ok(())
2773}
2774
2775/// Updates other participants with changes to the project
2776async fn update_project(
2777    request: proto::UpdateProject,
2778    response: Response<proto::UpdateProject>,
2779    session: Session,
2780) -> Result<()> {
2781    let project_id = ProjectId::from_proto(request.project_id);
2782    let (room, guest_connection_ids) = &*session
2783        .db()
2784        .await
2785        .update_project(project_id, session.connection_id, &request.worktrees)
2786        .await?;
2787    broadcast(
2788        Some(session.connection_id),
2789        guest_connection_ids.iter().copied(),
2790        |connection_id| {
2791            session
2792                .peer
2793                .forward_send(session.connection_id, connection_id, request.clone())
2794        },
2795    );
2796    if let Some(room) = room {
2797        room_updated(&room, &session.peer);
2798    }
2799    response.send(proto::Ack {})?;
2800
2801    Ok(())
2802}
2803
2804/// Updates other participants with changes to the worktree
2805async fn update_worktree(
2806    request: proto::UpdateWorktree,
2807    response: Response<proto::UpdateWorktree>,
2808    session: Session,
2809) -> Result<()> {
2810    let guest_connection_ids = session
2811        .db()
2812        .await
2813        .update_worktree(&request, session.connection_id)
2814        .await?;
2815
2816    broadcast(
2817        Some(session.connection_id),
2818        guest_connection_ids.iter().copied(),
2819        |connection_id| {
2820            session
2821                .peer
2822                .forward_send(session.connection_id, connection_id, request.clone())
2823        },
2824    );
2825    response.send(proto::Ack {})?;
2826    Ok(())
2827}
2828
2829/// Updates other participants with changes to the diagnostics
2830async fn update_diagnostic_summary(
2831    message: proto::UpdateDiagnosticSummary,
2832    session: Session,
2833) -> Result<()> {
2834    let guest_connection_ids = session
2835        .db()
2836        .await
2837        .update_diagnostic_summary(&message, session.connection_id)
2838        .await?;
2839
2840    broadcast(
2841        Some(session.connection_id),
2842        guest_connection_ids.iter().copied(),
2843        |connection_id| {
2844            session
2845                .peer
2846                .forward_send(session.connection_id, connection_id, message.clone())
2847        },
2848    );
2849
2850    Ok(())
2851}
2852
2853/// Updates other participants with changes to the worktree settings
2854async fn update_worktree_settings(
2855    message: proto::UpdateWorktreeSettings,
2856    session: Session,
2857) -> Result<()> {
2858    let guest_connection_ids = session
2859        .db()
2860        .await
2861        .update_worktree_settings(&message, session.connection_id)
2862        .await?;
2863
2864    broadcast(
2865        Some(session.connection_id),
2866        guest_connection_ids.iter().copied(),
2867        |connection_id| {
2868            session
2869                .peer
2870                .forward_send(session.connection_id, connection_id, message.clone())
2871        },
2872    );
2873
2874    Ok(())
2875}
2876
2877/// Notify other participants that a  language server has started.
2878async fn start_language_server(
2879    request: proto::StartLanguageServer,
2880    session: Session,
2881) -> Result<()> {
2882    let guest_connection_ids = session
2883        .db()
2884        .await
2885        .start_language_server(&request, session.connection_id)
2886        .await?;
2887
2888    broadcast(
2889        Some(session.connection_id),
2890        guest_connection_ids.iter().copied(),
2891        |connection_id| {
2892            session
2893                .peer
2894                .forward_send(session.connection_id, connection_id, request.clone())
2895        },
2896    );
2897    Ok(())
2898}
2899
2900/// Notify other participants that a language server has changed.
2901async fn update_language_server(
2902    request: proto::UpdateLanguageServer,
2903    session: Session,
2904) -> Result<()> {
2905    let project_id = ProjectId::from_proto(request.project_id);
2906    let project_connection_ids = session
2907        .db()
2908        .await
2909        .project_connection_ids(project_id, session.connection_id, true)
2910        .await?;
2911    broadcast(
2912        Some(session.connection_id),
2913        project_connection_ids.iter().copied(),
2914        |connection_id| {
2915            session
2916                .peer
2917                .forward_send(session.connection_id, connection_id, request.clone())
2918        },
2919    );
2920    Ok(())
2921}
2922
2923/// forward a project request to the host. These requests should be read only
2924/// as guests are allowed to send them.
2925async fn forward_read_only_project_request<T>(
2926    request: T,
2927    response: Response<T>,
2928    session: UserSession,
2929) -> Result<()>
2930where
2931    T: EntityMessage + RequestMessage,
2932{
2933    let project_id = ProjectId::from_proto(request.remote_entity_id());
2934    let host_connection_id = session
2935        .db()
2936        .await
2937        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2938        .await?;
2939    let payload = session
2940        .peer
2941        .forward_request(session.connection_id, host_connection_id, request)
2942        .await?;
2943    response.send(payload)?;
2944    Ok(())
2945}
2946
2947async fn forward_find_search_candidates_request(
2948    request: proto::FindSearchCandidates,
2949    response: Response<proto::FindSearchCandidates>,
2950    session: UserSession,
2951) -> Result<()> {
2952    let project_id = ProjectId::from_proto(request.remote_entity_id());
2953    let host_connection_id = session
2954        .db()
2955        .await
2956        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2957        .await?;
2958
2959    let host_version = session
2960        .connection_pool()
2961        .await
2962        .connection(host_connection_id)
2963        .map(|c| c.zed_version);
2964
2965    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2966    {
2967        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2968        let search = proto::SearchProject {
2969            project_id: project_id.to_proto(),
2970            query: query.query,
2971            regex: query.regex,
2972            whole_word: query.whole_word,
2973            case_sensitive: query.case_sensitive,
2974            files_to_include: query.files_to_include,
2975            files_to_exclude: query.files_to_exclude,
2976            include_ignored: query.include_ignored,
2977        };
2978
2979        let payload = session
2980            .peer
2981            .forward_request(session.connection_id, host_connection_id, search)
2982            .await?;
2983        return response.send(proto::FindSearchCandidatesResponse {
2984            buffer_ids: payload
2985                .locations
2986                .into_iter()
2987                .map(|loc| loc.buffer_id)
2988                .collect(),
2989        });
2990    }
2991
2992    let payload = session
2993        .peer
2994        .forward_request(session.connection_id, host_connection_id, request)
2995        .await?;
2996    response.send(payload)?;
2997    Ok(())
2998}
2999
3000/// forward a project request to the dev server. Only allowed
3001/// if it's your dev server.
3002async fn forward_project_request_for_owner<T>(
3003    request: T,
3004    response: Response<T>,
3005    session: UserSession,
3006) -> Result<()>
3007where
3008    T: EntityMessage + RequestMessage,
3009{
3010    let project_id = ProjectId::from_proto(request.remote_entity_id());
3011
3012    let host_connection_id = session
3013        .db()
3014        .await
3015        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3016        .await?;
3017    let payload = session
3018        .peer
3019        .forward_request(session.connection_id, host_connection_id, request)
3020        .await?;
3021    response.send(payload)?;
3022    Ok(())
3023}
3024
3025/// forward a project request to the host. These requests are disallowed
3026/// for guests.
3027async fn forward_mutating_project_request<T>(
3028    request: T,
3029    response: Response<T>,
3030    session: UserSession,
3031) -> Result<()>
3032where
3033    T: EntityMessage + RequestMessage,
3034{
3035    let project_id = ProjectId::from_proto(request.remote_entity_id());
3036
3037    let host_connection_id = session
3038        .db()
3039        .await
3040        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3041        .await?;
3042    let payload = session
3043        .peer
3044        .forward_request(session.connection_id, host_connection_id, request)
3045        .await?;
3046    response.send(payload)?;
3047    Ok(())
3048}
3049
3050/// forward a project request to the host. These requests are disallowed
3051/// for guests.
3052async fn forward_versioned_mutating_project_request<T>(
3053    request: T,
3054    response: Response<T>,
3055    session: UserSession,
3056) -> Result<()>
3057where
3058    T: EntityMessage + RequestMessage + VersionedMessage,
3059{
3060    let project_id = ProjectId::from_proto(request.remote_entity_id());
3061
3062    let host_connection_id = session
3063        .db()
3064        .await
3065        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3066        .await?;
3067    if let Some(host_version) = session
3068        .connection_pool()
3069        .await
3070        .connection(host_connection_id)
3071        .map(|c| c.zed_version)
3072    {
3073        if let Some(min_required_version) = request.required_host_version() {
3074            if min_required_version > host_version {
3075                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3076                    .with_tag("required", &min_required_version.to_string())))?;
3077            }
3078        }
3079    }
3080
3081    let payload = session
3082        .peer
3083        .forward_request(session.connection_id, host_connection_id, request)
3084        .await?;
3085    response.send(payload)?;
3086    Ok(())
3087}
3088
3089/// Notify other participants that a new buffer has been created
3090async fn create_buffer_for_peer(
3091    request: proto::CreateBufferForPeer,
3092    session: Session,
3093) -> Result<()> {
3094    session
3095        .db()
3096        .await
3097        .check_user_is_project_host(
3098            ProjectId::from_proto(request.project_id),
3099            session.connection_id,
3100        )
3101        .await?;
3102    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3103    session
3104        .peer
3105        .forward_send(session.connection_id, peer_id.into(), request)?;
3106    Ok(())
3107}
3108
3109/// Notify other participants that a buffer has been updated. This is
3110/// allowed for guests as long as the update is limited to selections.
3111async fn update_buffer(
3112    request: proto::UpdateBuffer,
3113    response: Response<proto::UpdateBuffer>,
3114    session: Session,
3115) -> Result<()> {
3116    let project_id = ProjectId::from_proto(request.project_id);
3117    let mut capability = Capability::ReadOnly;
3118
3119    for op in request.operations.iter() {
3120        match op.variant {
3121            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3122            Some(_) => capability = Capability::ReadWrite,
3123        }
3124    }
3125
3126    let host = {
3127        let guard = session
3128            .db()
3129            .await
3130            .connections_for_buffer_update(
3131                project_id,
3132                session.principal_id(),
3133                session.connection_id,
3134                capability,
3135            )
3136            .await?;
3137
3138        let (host, guests) = &*guard;
3139
3140        broadcast(
3141            Some(session.connection_id),
3142            guests.clone(),
3143            |connection_id| {
3144                session
3145                    .peer
3146                    .forward_send(session.connection_id, connection_id, request.clone())
3147            },
3148        );
3149
3150        *host
3151    };
3152
3153    if host != session.connection_id {
3154        session
3155            .peer
3156            .forward_request(session.connection_id, host, request.clone())
3157            .await?;
3158    }
3159
3160    response.send(proto::Ack {})?;
3161    Ok(())
3162}
3163
3164async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3165    let project_id = ProjectId::from_proto(message.project_id);
3166
3167    let operation = message.operation.as_ref().context("invalid operation")?;
3168    let capability = match operation.variant.as_ref() {
3169        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3170            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3171                match buffer_op.variant {
3172                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3173                        Capability::ReadOnly
3174                    }
3175                    _ => Capability::ReadWrite,
3176                }
3177            } else {
3178                Capability::ReadWrite
3179            }
3180        }
3181        Some(_) => Capability::ReadWrite,
3182        None => Capability::ReadOnly,
3183    };
3184
3185    let guard = session
3186        .db()
3187        .await
3188        .connections_for_buffer_update(
3189            project_id,
3190            session.principal_id(),
3191            session.connection_id,
3192            capability,
3193        )
3194        .await?;
3195
3196    let (host, guests) = &*guard;
3197
3198    broadcast(
3199        Some(session.connection_id),
3200        guests.iter().chain([host]).copied(),
3201        |connection_id| {
3202            session
3203                .peer
3204                .forward_send(session.connection_id, connection_id, message.clone())
3205        },
3206    );
3207
3208    Ok(())
3209}
3210
3211/// Notify other participants that a project has been updated.
3212async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3213    request: T,
3214    session: Session,
3215) -> Result<()> {
3216    let project_id = ProjectId::from_proto(request.remote_entity_id());
3217    let project_connection_ids = session
3218        .db()
3219        .await
3220        .project_connection_ids(project_id, session.connection_id, false)
3221        .await?;
3222
3223    broadcast(
3224        Some(session.connection_id),
3225        project_connection_ids.iter().copied(),
3226        |connection_id| {
3227            session
3228                .peer
3229                .forward_send(session.connection_id, connection_id, request.clone())
3230        },
3231    );
3232    Ok(())
3233}
3234
3235/// Start following another user in a call.
3236async fn follow(
3237    request: proto::Follow,
3238    response: Response<proto::Follow>,
3239    session: UserSession,
3240) -> Result<()> {
3241    let room_id = RoomId::from_proto(request.room_id);
3242    let project_id = request.project_id.map(ProjectId::from_proto);
3243    let leader_id = request
3244        .leader_id
3245        .ok_or_else(|| anyhow!("invalid leader id"))?
3246        .into();
3247    let follower_id = session.connection_id;
3248
3249    session
3250        .db()
3251        .await
3252        .check_room_participants(room_id, leader_id, session.connection_id)
3253        .await?;
3254
3255    let response_payload = session
3256        .peer
3257        .forward_request(session.connection_id, leader_id, request)
3258        .await?;
3259    response.send(response_payload)?;
3260
3261    if let Some(project_id) = project_id {
3262        let room = session
3263            .db()
3264            .await
3265            .follow(room_id, project_id, leader_id, follower_id)
3266            .await?;
3267        room_updated(&room, &session.peer);
3268    }
3269
3270    Ok(())
3271}
3272
3273/// Stop following another user in a call.
3274async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3275    let room_id = RoomId::from_proto(request.room_id);
3276    let project_id = request.project_id.map(ProjectId::from_proto);
3277    let leader_id = request
3278        .leader_id
3279        .ok_or_else(|| anyhow!("invalid leader id"))?
3280        .into();
3281    let follower_id = session.connection_id;
3282
3283    session
3284        .db()
3285        .await
3286        .check_room_participants(room_id, leader_id, session.connection_id)
3287        .await?;
3288
3289    session
3290        .peer
3291        .forward_send(session.connection_id, leader_id, request)?;
3292
3293    if let Some(project_id) = project_id {
3294        let room = session
3295            .db()
3296            .await
3297            .unfollow(room_id, project_id, leader_id, follower_id)
3298            .await?;
3299        room_updated(&room, &session.peer);
3300    }
3301
3302    Ok(())
3303}
3304
3305/// Notify everyone following you of your current location.
3306async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3307    let room_id = RoomId::from_proto(request.room_id);
3308    let database = session.db.lock().await;
3309
3310    let connection_ids = if let Some(project_id) = request.project_id {
3311        let project_id = ProjectId::from_proto(project_id);
3312        database
3313            .project_connection_ids(project_id, session.connection_id, true)
3314            .await?
3315    } else {
3316        database
3317            .room_connection_ids(room_id, session.connection_id)
3318            .await?
3319    };
3320
3321    // For now, don't send view update messages back to that view's current leader.
3322    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3323        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3324        _ => None,
3325    });
3326
3327    for connection_id in connection_ids.iter().cloned() {
3328        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3329            session
3330                .peer
3331                .forward_send(session.connection_id, connection_id, request.clone())?;
3332        }
3333    }
3334    Ok(())
3335}
3336
3337/// Get public data about users.
3338async fn get_users(
3339    request: proto::GetUsers,
3340    response: Response<proto::GetUsers>,
3341    session: Session,
3342) -> Result<()> {
3343    let user_ids = request
3344        .user_ids
3345        .into_iter()
3346        .map(UserId::from_proto)
3347        .collect();
3348    let users = session
3349        .db()
3350        .await
3351        .get_users_by_ids(user_ids)
3352        .await?
3353        .into_iter()
3354        .map(|user| proto::User {
3355            id: user.id.to_proto(),
3356            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3357            github_login: user.github_login,
3358        })
3359        .collect();
3360    response.send(proto::UsersResponse { users })?;
3361    Ok(())
3362}
3363
3364/// Search for users (to invite) buy Github login
3365async fn fuzzy_search_users(
3366    request: proto::FuzzySearchUsers,
3367    response: Response<proto::FuzzySearchUsers>,
3368    session: UserSession,
3369) -> Result<()> {
3370    let query = request.query;
3371    let users = match query.len() {
3372        0 => vec![],
3373        1 | 2 => session
3374            .db()
3375            .await
3376            .get_user_by_github_login(&query)
3377            .await?
3378            .into_iter()
3379            .collect(),
3380        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3381    };
3382    let users = users
3383        .into_iter()
3384        .filter(|user| user.id != session.user_id())
3385        .map(|user| proto::User {
3386            id: user.id.to_proto(),
3387            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3388            github_login: user.github_login,
3389        })
3390        .collect();
3391    response.send(proto::UsersResponse { users })?;
3392    Ok(())
3393}
3394
3395/// Send a contact request to another user.
3396async fn request_contact(
3397    request: proto::RequestContact,
3398    response: Response<proto::RequestContact>,
3399    session: UserSession,
3400) -> Result<()> {
3401    let requester_id = session.user_id();
3402    let responder_id = UserId::from_proto(request.responder_id);
3403    if requester_id == responder_id {
3404        return Err(anyhow!("cannot add yourself as a contact"))?;
3405    }
3406
3407    let notifications = session
3408        .db()
3409        .await
3410        .send_contact_request(requester_id, responder_id)
3411        .await?;
3412
3413    // Update outgoing contact requests of requester
3414    let mut update = proto::UpdateContacts::default();
3415    update.outgoing_requests.push(responder_id.to_proto());
3416    for connection_id in session
3417        .connection_pool()
3418        .await
3419        .user_connection_ids(requester_id)
3420    {
3421        session.peer.send(connection_id, update.clone())?;
3422    }
3423
3424    // Update incoming contact requests of responder
3425    let mut update = proto::UpdateContacts::default();
3426    update
3427        .incoming_requests
3428        .push(proto::IncomingContactRequest {
3429            requester_id: requester_id.to_proto(),
3430        });
3431    let connection_pool = session.connection_pool().await;
3432    for connection_id in connection_pool.user_connection_ids(responder_id) {
3433        session.peer.send(connection_id, update.clone())?;
3434    }
3435
3436    send_notifications(&connection_pool, &session.peer, notifications);
3437
3438    response.send(proto::Ack {})?;
3439    Ok(())
3440}
3441
3442/// Accept or decline a contact request
3443async fn respond_to_contact_request(
3444    request: proto::RespondToContactRequest,
3445    response: Response<proto::RespondToContactRequest>,
3446    session: UserSession,
3447) -> Result<()> {
3448    let responder_id = session.user_id();
3449    let requester_id = UserId::from_proto(request.requester_id);
3450    let db = session.db().await;
3451    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3452        db.dismiss_contact_notification(responder_id, requester_id)
3453            .await?;
3454    } else {
3455        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3456
3457        let notifications = db
3458            .respond_to_contact_request(responder_id, requester_id, accept)
3459            .await?;
3460        let requester_busy = db.is_user_busy(requester_id).await?;
3461        let responder_busy = db.is_user_busy(responder_id).await?;
3462
3463        let pool = session.connection_pool().await;
3464        // Update responder with new contact
3465        let mut update = proto::UpdateContacts::default();
3466        if accept {
3467            update
3468                .contacts
3469                .push(contact_for_user(requester_id, requester_busy, &pool));
3470        }
3471        update
3472            .remove_incoming_requests
3473            .push(requester_id.to_proto());
3474        for connection_id in pool.user_connection_ids(responder_id) {
3475            session.peer.send(connection_id, update.clone())?;
3476        }
3477
3478        // Update requester with new contact
3479        let mut update = proto::UpdateContacts::default();
3480        if accept {
3481            update
3482                .contacts
3483                .push(contact_for_user(responder_id, responder_busy, &pool));
3484        }
3485        update
3486            .remove_outgoing_requests
3487            .push(responder_id.to_proto());
3488
3489        for connection_id in pool.user_connection_ids(requester_id) {
3490            session.peer.send(connection_id, update.clone())?;
3491        }
3492
3493        send_notifications(&pool, &session.peer, notifications);
3494    }
3495
3496    response.send(proto::Ack {})?;
3497    Ok(())
3498}
3499
3500/// Remove a contact.
3501async fn remove_contact(
3502    request: proto::RemoveContact,
3503    response: Response<proto::RemoveContact>,
3504    session: UserSession,
3505) -> Result<()> {
3506    let requester_id = session.user_id();
3507    let responder_id = UserId::from_proto(request.user_id);
3508    let db = session.db().await;
3509    let (contact_accepted, deleted_notification_id) =
3510        db.remove_contact(requester_id, responder_id).await?;
3511
3512    let pool = session.connection_pool().await;
3513    // Update outgoing contact requests of requester
3514    let mut update = proto::UpdateContacts::default();
3515    if contact_accepted {
3516        update.remove_contacts.push(responder_id.to_proto());
3517    } else {
3518        update
3519            .remove_outgoing_requests
3520            .push(responder_id.to_proto());
3521    }
3522    for connection_id in pool.user_connection_ids(requester_id) {
3523        session.peer.send(connection_id, update.clone())?;
3524    }
3525
3526    // Update incoming contact requests of responder
3527    let mut update = proto::UpdateContacts::default();
3528    if contact_accepted {
3529        update.remove_contacts.push(requester_id.to_proto());
3530    } else {
3531        update
3532            .remove_incoming_requests
3533            .push(requester_id.to_proto());
3534    }
3535    for connection_id in pool.user_connection_ids(responder_id) {
3536        session.peer.send(connection_id, update.clone())?;
3537        if let Some(notification_id) = deleted_notification_id {
3538            session.peer.send(
3539                connection_id,
3540                proto::DeleteNotification {
3541                    notification_id: notification_id.to_proto(),
3542                },
3543            )?;
3544        }
3545    }
3546
3547    response.send(proto::Ack {})?;
3548    Ok(())
3549}
3550
3551fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3552    version.0.minor() < 139
3553}
3554
3555async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3556    let plan = session.current_plan(session.db().await).await?;
3557
3558    session
3559        .peer
3560        .send(
3561            session.connection_id,
3562            proto::UpdateUserPlan { plan: plan.into() },
3563        )
3564        .trace_err();
3565
3566    Ok(())
3567}
3568
3569async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3570    subscribe_user_to_channels(
3571        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3572        &session,
3573    )
3574    .await?;
3575    Ok(())
3576}
3577
3578async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3579    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3580    let mut pool = session.connection_pool().await;
3581    for membership in &channels_for_user.channel_memberships {
3582        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3583    }
3584    session.peer.send(
3585        session.connection_id,
3586        build_update_user_channels(&channels_for_user),
3587    )?;
3588    session.peer.send(
3589        session.connection_id,
3590        build_channels_update(channels_for_user),
3591    )?;
3592    Ok(())
3593}
3594
3595/// Creates a new channel.
3596async fn create_channel(
3597    request: proto::CreateChannel,
3598    response: Response<proto::CreateChannel>,
3599    session: UserSession,
3600) -> Result<()> {
3601    let db = session.db().await;
3602
3603    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3604    let (channel, membership) = db
3605        .create_channel(&request.name, parent_id, session.user_id())
3606        .await?;
3607
3608    let root_id = channel.root_id();
3609    let channel = Channel::from_model(channel);
3610
3611    response.send(proto::CreateChannelResponse {
3612        channel: Some(channel.to_proto()),
3613        parent_id: request.parent_id,
3614    })?;
3615
3616    let mut connection_pool = session.connection_pool().await;
3617    if let Some(membership) = membership {
3618        connection_pool.subscribe_to_channel(
3619            membership.user_id,
3620            membership.channel_id,
3621            membership.role,
3622        );
3623        let update = proto::UpdateUserChannels {
3624            channel_memberships: vec![proto::ChannelMembership {
3625                channel_id: membership.channel_id.to_proto(),
3626                role: membership.role.into(),
3627            }],
3628            ..Default::default()
3629        };
3630        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3631            session.peer.send(connection_id, update.clone())?;
3632        }
3633    }
3634
3635    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3636        if !role.can_see_channel(channel.visibility) {
3637            continue;
3638        }
3639
3640        let update = proto::UpdateChannels {
3641            channels: vec![channel.to_proto()],
3642            ..Default::default()
3643        };
3644        session.peer.send(connection_id, update.clone())?;
3645    }
3646
3647    Ok(())
3648}
3649
3650/// Delete a channel
3651async fn delete_channel(
3652    request: proto::DeleteChannel,
3653    response: Response<proto::DeleteChannel>,
3654    session: UserSession,
3655) -> Result<()> {
3656    let db = session.db().await;
3657
3658    let channel_id = request.channel_id;
3659    let (root_channel, removed_channels) = db
3660        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3661        .await?;
3662    response.send(proto::Ack {})?;
3663
3664    // Notify members of removed channels
3665    let mut update = proto::UpdateChannels::default();
3666    update
3667        .delete_channels
3668        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3669
3670    let connection_pool = session.connection_pool().await;
3671    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3672        session.peer.send(connection_id, update.clone())?;
3673    }
3674
3675    Ok(())
3676}
3677
3678/// Invite someone to join a channel.
3679async fn invite_channel_member(
3680    request: proto::InviteChannelMember,
3681    response: Response<proto::InviteChannelMember>,
3682    session: UserSession,
3683) -> Result<()> {
3684    let db = session.db().await;
3685    let channel_id = ChannelId::from_proto(request.channel_id);
3686    let invitee_id = UserId::from_proto(request.user_id);
3687    let InviteMemberResult {
3688        channel,
3689        notifications,
3690    } = db
3691        .invite_channel_member(
3692            channel_id,
3693            invitee_id,
3694            session.user_id(),
3695            request.role().into(),
3696        )
3697        .await?;
3698
3699    let update = proto::UpdateChannels {
3700        channel_invitations: vec![channel.to_proto()],
3701        ..Default::default()
3702    };
3703
3704    let connection_pool = session.connection_pool().await;
3705    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3706        session.peer.send(connection_id, update.clone())?;
3707    }
3708
3709    send_notifications(&connection_pool, &session.peer, notifications);
3710
3711    response.send(proto::Ack {})?;
3712    Ok(())
3713}
3714
3715/// remove someone from a channel
3716async fn remove_channel_member(
3717    request: proto::RemoveChannelMember,
3718    response: Response<proto::RemoveChannelMember>,
3719    session: UserSession,
3720) -> Result<()> {
3721    let db = session.db().await;
3722    let channel_id = ChannelId::from_proto(request.channel_id);
3723    let member_id = UserId::from_proto(request.user_id);
3724
3725    let RemoveChannelMemberResult {
3726        membership_update,
3727        notification_id,
3728    } = db
3729        .remove_channel_member(channel_id, member_id, session.user_id())
3730        .await?;
3731
3732    let mut connection_pool = session.connection_pool().await;
3733    notify_membership_updated(
3734        &mut connection_pool,
3735        membership_update,
3736        member_id,
3737        &session.peer,
3738    );
3739    for connection_id in connection_pool.user_connection_ids(member_id) {
3740        if let Some(notification_id) = notification_id {
3741            session
3742                .peer
3743                .send(
3744                    connection_id,
3745                    proto::DeleteNotification {
3746                        notification_id: notification_id.to_proto(),
3747                    },
3748                )
3749                .trace_err();
3750        }
3751    }
3752
3753    response.send(proto::Ack {})?;
3754    Ok(())
3755}
3756
3757/// Toggle the channel between public and private.
3758/// Care is taken to maintain the invariant that public channels only descend from public channels,
3759/// (though members-only channels can appear at any point in the hierarchy).
3760async fn set_channel_visibility(
3761    request: proto::SetChannelVisibility,
3762    response: Response<proto::SetChannelVisibility>,
3763    session: UserSession,
3764) -> Result<()> {
3765    let db = session.db().await;
3766    let channel_id = ChannelId::from_proto(request.channel_id);
3767    let visibility = request.visibility().into();
3768
3769    let channel_model = db
3770        .set_channel_visibility(channel_id, visibility, session.user_id())
3771        .await?;
3772    let root_id = channel_model.root_id();
3773    let channel = Channel::from_model(channel_model);
3774
3775    let mut connection_pool = session.connection_pool().await;
3776    for (user_id, role) in connection_pool
3777        .channel_user_ids(root_id)
3778        .collect::<Vec<_>>()
3779        .into_iter()
3780    {
3781        let update = if role.can_see_channel(channel.visibility) {
3782            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3783            proto::UpdateChannels {
3784                channels: vec![channel.to_proto()],
3785                ..Default::default()
3786            }
3787        } else {
3788            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3789            proto::UpdateChannels {
3790                delete_channels: vec![channel.id.to_proto()],
3791                ..Default::default()
3792            }
3793        };
3794
3795        for connection_id in connection_pool.user_connection_ids(user_id) {
3796            session.peer.send(connection_id, update.clone())?;
3797        }
3798    }
3799
3800    response.send(proto::Ack {})?;
3801    Ok(())
3802}
3803
3804/// Alter the role for a user in the channel.
3805async fn set_channel_member_role(
3806    request: proto::SetChannelMemberRole,
3807    response: Response<proto::SetChannelMemberRole>,
3808    session: UserSession,
3809) -> Result<()> {
3810    let db = session.db().await;
3811    let channel_id = ChannelId::from_proto(request.channel_id);
3812    let member_id = UserId::from_proto(request.user_id);
3813    let result = db
3814        .set_channel_member_role(
3815            channel_id,
3816            session.user_id(),
3817            member_id,
3818            request.role().into(),
3819        )
3820        .await?;
3821
3822    match result {
3823        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3824            let mut connection_pool = session.connection_pool().await;
3825            notify_membership_updated(
3826                &mut connection_pool,
3827                membership_update,
3828                member_id,
3829                &session.peer,
3830            )
3831        }
3832        db::SetMemberRoleResult::InviteUpdated(channel) => {
3833            let update = proto::UpdateChannels {
3834                channel_invitations: vec![channel.to_proto()],
3835                ..Default::default()
3836            };
3837
3838            for connection_id in session
3839                .connection_pool()
3840                .await
3841                .user_connection_ids(member_id)
3842            {
3843                session.peer.send(connection_id, update.clone())?;
3844            }
3845        }
3846    }
3847
3848    response.send(proto::Ack {})?;
3849    Ok(())
3850}
3851
3852/// Change the name of a channel
3853async fn rename_channel(
3854    request: proto::RenameChannel,
3855    response: Response<proto::RenameChannel>,
3856    session: UserSession,
3857) -> Result<()> {
3858    let db = session.db().await;
3859    let channel_id = ChannelId::from_proto(request.channel_id);
3860    let channel_model = db
3861        .rename_channel(channel_id, session.user_id(), &request.name)
3862        .await?;
3863    let root_id = channel_model.root_id();
3864    let channel = Channel::from_model(channel_model);
3865
3866    response.send(proto::RenameChannelResponse {
3867        channel: Some(channel.to_proto()),
3868    })?;
3869
3870    let connection_pool = session.connection_pool().await;
3871    let update = proto::UpdateChannels {
3872        channels: vec![channel.to_proto()],
3873        ..Default::default()
3874    };
3875    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3876        if role.can_see_channel(channel.visibility) {
3877            session.peer.send(connection_id, update.clone())?;
3878        }
3879    }
3880
3881    Ok(())
3882}
3883
3884/// Move a channel to a new parent.
3885async fn move_channel(
3886    request: proto::MoveChannel,
3887    response: Response<proto::MoveChannel>,
3888    session: UserSession,
3889) -> Result<()> {
3890    let channel_id = ChannelId::from_proto(request.channel_id);
3891    let to = ChannelId::from_proto(request.to);
3892
3893    let (root_id, channels) = session
3894        .db()
3895        .await
3896        .move_channel(channel_id, to, session.user_id())
3897        .await?;
3898
3899    let connection_pool = session.connection_pool().await;
3900    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3901        let channels = channels
3902            .iter()
3903            .filter_map(|channel| {
3904                if role.can_see_channel(channel.visibility) {
3905                    Some(channel.to_proto())
3906                } else {
3907                    None
3908                }
3909            })
3910            .collect::<Vec<_>>();
3911        if channels.is_empty() {
3912            continue;
3913        }
3914
3915        let update = proto::UpdateChannels {
3916            channels,
3917            ..Default::default()
3918        };
3919
3920        session.peer.send(connection_id, update.clone())?;
3921    }
3922
3923    response.send(Ack {})?;
3924    Ok(())
3925}
3926
3927/// Get the list of channel members
3928async fn get_channel_members(
3929    request: proto::GetChannelMembers,
3930    response: Response<proto::GetChannelMembers>,
3931    session: UserSession,
3932) -> Result<()> {
3933    let db = session.db().await;
3934    let channel_id = ChannelId::from_proto(request.channel_id);
3935    let limit = if request.limit == 0 {
3936        u16::MAX as u64
3937    } else {
3938        request.limit
3939    };
3940    let (members, users) = db
3941        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3942        .await?;
3943    response.send(proto::GetChannelMembersResponse { members, users })?;
3944    Ok(())
3945}
3946
3947/// Accept or decline a channel invitation.
3948async fn respond_to_channel_invite(
3949    request: proto::RespondToChannelInvite,
3950    response: Response<proto::RespondToChannelInvite>,
3951    session: UserSession,
3952) -> Result<()> {
3953    let db = session.db().await;
3954    let channel_id = ChannelId::from_proto(request.channel_id);
3955    let RespondToChannelInvite {
3956        membership_update,
3957        notifications,
3958    } = db
3959        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3960        .await?;
3961
3962    let mut connection_pool = session.connection_pool().await;
3963    if let Some(membership_update) = membership_update {
3964        notify_membership_updated(
3965            &mut connection_pool,
3966            membership_update,
3967            session.user_id(),
3968            &session.peer,
3969        );
3970    } else {
3971        let update = proto::UpdateChannels {
3972            remove_channel_invitations: vec![channel_id.to_proto()],
3973            ..Default::default()
3974        };
3975
3976        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3977            session.peer.send(connection_id, update.clone())?;
3978        }
3979    };
3980
3981    send_notifications(&connection_pool, &session.peer, notifications);
3982
3983    response.send(proto::Ack {})?;
3984
3985    Ok(())
3986}
3987
3988/// Join the channels' room
3989async fn join_channel(
3990    request: proto::JoinChannel,
3991    response: Response<proto::JoinChannel>,
3992    session: UserSession,
3993) -> Result<()> {
3994    let channel_id = ChannelId::from_proto(request.channel_id);
3995    join_channel_internal(channel_id, Box::new(response), session).await
3996}
3997
/// Abstracts over the responders whose reply is a `proto::JoinRoomResponse`
/// (`JoinChannel` and `JoinRoom` below), so `join_channel_internal` can answer either.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified path: inherent `Response::send` takes precedence over this trait
        // method, so this forwards rather than recursing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding to the inherent `Response::send` as above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4011
/// Shared implementation for joining a channel's room, replying through any
/// `JoinChannelInternalResponse`.
///
/// In order, it:
/// 1. if the database reports a stale room connection for this user, releases the db guard,
///    calls `leave_room_for_session` for that connection, and re-acquires the guard;
/// 2. joins the channel's room in the database;
/// 3. when a LiveKit client is configured, mints a guest token (`can_publish: false`) for
///    `ChannelRole::Guest` or a room token otherwise. A token error is logged via
///    `trace_err` and the reply then carries no LiveKit connection info;
/// 4. replies with a `JoinRoomResponse`, calls `notify_membership_updated` when the
///    membership changed, and calls `room_updated`;
/// 5. once the db and pool guards from the steps above are dropped, calls `channel_updated`
///    and `update_user_contacts`.
///
/// Errors if a database call or the reply fails, or if the database does not return the
/// joined channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the db guard and the connection-pool guard acquired inside it, so
    // both are released before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room; NOTE(review): presumably
            // `leave_room_for_session` acquires the db itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `trace_err()?` inside `and_then` logs a token error and yields `None`, so a
        // LiveKit failure does not fail the join itself.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4107
4108/// Start editing the channel notes
4109async fn join_channel_buffer(
4110    request: proto::JoinChannelBuffer,
4111    response: Response<proto::JoinChannelBuffer>,
4112    session: UserSession,
4113) -> Result<()> {
4114    let db = session.db().await;
4115    let channel_id = ChannelId::from_proto(request.channel_id);
4116
4117    let open_response = db
4118        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4119        .await?;
4120
4121    let collaborators = open_response.collaborators.clone();
4122    response.send(open_response)?;
4123
4124    let update = UpdateChannelBufferCollaborators {
4125        channel_id: channel_id.to_proto(),
4126        collaborators: collaborators.clone(),
4127    };
4128    channel_buffer_updated(
4129        session.connection_id,
4130        collaborators
4131            .iter()
4132            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4133        &update,
4134        &session.peer,
4135    );
4136
4137    Ok(())
4138}
4139
4140/// Edit the channel notes
4141async fn update_channel_buffer(
4142    request: proto::UpdateChannelBuffer,
4143    session: UserSession,
4144) -> Result<()> {
4145    let db = session.db().await;
4146    let channel_id = ChannelId::from_proto(request.channel_id);
4147
4148    let (collaborators, epoch, version) = db
4149        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4150        .await?;
4151
4152    channel_buffer_updated(
4153        session.connection_id,
4154        collaborators.clone(),
4155        &proto::UpdateChannelBuffer {
4156            channel_id: channel_id.to_proto(),
4157            operations: request.operations,
4158        },
4159        &session.peer,
4160    );
4161
4162    let pool = &*session.connection_pool().await;
4163
4164    let non_collaborators =
4165        pool.channel_connection_ids(channel_id)
4166            .filter_map(|(connection_id, _)| {
4167                if collaborators.contains(&connection_id) {
4168                    None
4169                } else {
4170                    Some(connection_id)
4171                }
4172            });
4173
4174    broadcast(None, non_collaborators, |peer_id| {
4175        session.peer.send(
4176            peer_id,
4177            proto::UpdateChannels {
4178                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4179                    channel_id: channel_id.to_proto(),
4180                    epoch: epoch as u64,
4181                    version: version.clone(),
4182                }],
4183                ..Default::default()
4184            },
4185        )
4186    });
4187
4188    Ok(())
4189}
4190
4191/// Rejoin the channel notes after a connection blip
4192async fn rejoin_channel_buffers(
4193    request: proto::RejoinChannelBuffers,
4194    response: Response<proto::RejoinChannelBuffers>,
4195    session: UserSession,
4196) -> Result<()> {
4197    let db = session.db().await;
4198    let buffers = db
4199        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4200        .await?;
4201
4202    for rejoined_buffer in &buffers {
4203        let collaborators_to_notify = rejoined_buffer
4204            .buffer
4205            .collaborators
4206            .iter()
4207            .filter_map(|c| Some(c.peer_id?.into()));
4208        channel_buffer_updated(
4209            session.connection_id,
4210            collaborators_to_notify,
4211            &proto::UpdateChannelBufferCollaborators {
4212                channel_id: rejoined_buffer.buffer.channel_id,
4213                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4214            },
4215            &session.peer,
4216        );
4217    }
4218
4219    response.send(proto::RejoinChannelBuffersResponse {
4220        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4221    })?;
4222
4223    Ok(())
4224}
4225
4226/// Stop editing the channel notes
4227async fn leave_channel_buffer(
4228    request: proto::LeaveChannelBuffer,
4229    response: Response<proto::LeaveChannelBuffer>,
4230    session: UserSession,
4231) -> Result<()> {
4232    let db = session.db().await;
4233    let channel_id = ChannelId::from_proto(request.channel_id);
4234
4235    let left_buffer = db
4236        .leave_channel_buffer(channel_id, session.connection_id)
4237        .await?;
4238
4239    response.send(Ack {})?;
4240
4241    channel_buffer_updated(
4242        session.connection_id,
4243        left_buffer.connections,
4244        &proto::UpdateChannelBufferCollaborators {
4245            channel_id: channel_id.to_proto(),
4246            collaborators: left_buffer.collaborators,
4247        },
4248        &session.peer,
4249    );
4250
4251    Ok(())
4252}
4253
4254fn channel_buffer_updated<T: EnvelopedMessage>(
4255    sender_id: ConnectionId,
4256    collaborators: impl IntoIterator<Item = ConnectionId>,
4257    message: &T,
4258    peer: &Peer,
4259) {
4260    broadcast(Some(sender_id), collaborators, |peer_id| {
4261        peer.send(peer_id, message.clone())
4262    });
4263}
4264
4265fn send_notifications(
4266    connection_pool: &ConnectionPool,
4267    peer: &Peer,
4268    notifications: db::NotificationBatch,
4269) {
4270    for (user_id, notification) in notifications {
4271        for connection_id in connection_pool.user_connection_ids(user_id) {
4272            if let Err(error) = peer.send(
4273                connection_id,
4274                proto::AddNotification {
4275                    notification: Some(notification.clone()),
4276                },
4277            ) {
4278                tracing::error!(
4279                    "failed to send notification to {:?} {}",
4280                    connection_id,
4281                    error
4282                );
4283            }
4284        }
4285    }
4286}
4287
4288/// Send a message to the channel
4289async fn send_channel_message(
4290    request: proto::SendChannelMessage,
4291    response: Response<proto::SendChannelMessage>,
4292    session: UserSession,
4293) -> Result<()> {
4294    // Validate the message body.
4295    let body = request.body.trim().to_string();
4296    if body.len() > MAX_MESSAGE_LEN {
4297        return Err(anyhow!("message is too long"))?;
4298    }
4299    if body.is_empty() {
4300        return Err(anyhow!("message can't be blank"))?;
4301    }
4302
4303    // TODO: adjust mentions if body is trimmed
4304
4305    let timestamp = OffsetDateTime::now_utc();
4306    let nonce = request
4307        .nonce
4308        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4309
4310    let channel_id = ChannelId::from_proto(request.channel_id);
4311    let CreatedChannelMessage {
4312        message_id,
4313        participant_connection_ids,
4314        notifications,
4315    } = session
4316        .db()
4317        .await
4318        .create_channel_message(
4319            channel_id,
4320            session.user_id(),
4321            &body,
4322            &request.mentions,
4323            timestamp,
4324            nonce.clone().into(),
4325            match request.reply_to_message_id {
4326                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4327                None => None,
4328            },
4329        )
4330        .await?;
4331
4332    let message = proto::ChannelMessage {
4333        sender_id: session.user_id().to_proto(),
4334        id: message_id.to_proto(),
4335        body,
4336        mentions: request.mentions,
4337        timestamp: timestamp.unix_timestamp() as u64,
4338        nonce: Some(nonce),
4339        reply_to_message_id: request.reply_to_message_id,
4340        edited_at: None,
4341    };
4342    broadcast(
4343        Some(session.connection_id),
4344        participant_connection_ids.clone(),
4345        |connection| {
4346            session.peer.send(
4347                connection,
4348                proto::ChannelMessageSent {
4349                    channel_id: channel_id.to_proto(),
4350                    message: Some(message.clone()),
4351                },
4352            )
4353        },
4354    );
4355    response.send(proto::SendChannelMessageResponse {
4356        message: Some(message),
4357    })?;
4358
4359    let pool = &*session.connection_pool().await;
4360    let non_participants =
4361        pool.channel_connection_ids(channel_id)
4362            .filter_map(|(connection_id, _)| {
4363                if participant_connection_ids.contains(&connection_id) {
4364                    None
4365                } else {
4366                    Some(connection_id)
4367                }
4368            });
4369    broadcast(None, non_participants, |peer_id| {
4370        session.peer.send(
4371            peer_id,
4372            proto::UpdateChannels {
4373                latest_channel_message_ids: vec![proto::ChannelMessageId {
4374                    channel_id: channel_id.to_proto(),
4375                    message_id: message_id.to_proto(),
4376                }],
4377                ..Default::default()
4378            },
4379        )
4380    });
4381    send_notifications(pool, &session.peer, notifications);
4382
4383    Ok(())
4384}
4385
4386/// Delete a channel message
4387async fn remove_channel_message(
4388    request: proto::RemoveChannelMessage,
4389    response: Response<proto::RemoveChannelMessage>,
4390    session: UserSession,
4391) -> Result<()> {
4392    let channel_id = ChannelId::from_proto(request.channel_id);
4393    let message_id = MessageId::from_proto(request.message_id);
4394    let (connection_ids, existing_notification_ids) = session
4395        .db()
4396        .await
4397        .remove_channel_message(channel_id, message_id, session.user_id())
4398        .await?;
4399
4400    broadcast(
4401        Some(session.connection_id),
4402        connection_ids,
4403        move |connection| {
4404            session.peer.send(connection, request.clone())?;
4405
4406            for notification_id in &existing_notification_ids {
4407                session.peer.send(
4408                    connection,
4409                    proto::DeleteNotification {
4410                        notification_id: (*notification_id).to_proto(),
4411                    },
4412                )?;
4413            }
4414
4415            Ok(())
4416        },
4417    );
4418    response.send(proto::Ack {})?;
4419    Ok(())
4420}
4421
4422async fn update_channel_message(
4423    request: proto::UpdateChannelMessage,
4424    response: Response<proto::UpdateChannelMessage>,
4425    session: UserSession,
4426) -> Result<()> {
4427    let channel_id = ChannelId::from_proto(request.channel_id);
4428    let message_id = MessageId::from_proto(request.message_id);
4429    let updated_at = OffsetDateTime::now_utc();
4430    let UpdatedChannelMessage {
4431        message_id,
4432        participant_connection_ids,
4433        notifications,
4434        reply_to_message_id,
4435        timestamp,
4436        deleted_mention_notification_ids,
4437        updated_mention_notifications,
4438    } = session
4439        .db()
4440        .await
4441        .update_channel_message(
4442            channel_id,
4443            message_id,
4444            session.user_id(),
4445            request.body.as_str(),
4446            &request.mentions,
4447            updated_at,
4448        )
4449        .await?;
4450
4451    let nonce = request
4452        .nonce
4453        .clone()
4454        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4455
4456    let message = proto::ChannelMessage {
4457        sender_id: session.user_id().to_proto(),
4458        id: message_id.to_proto(),
4459        body: request.body.clone(),
4460        mentions: request.mentions.clone(),
4461        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4462        nonce: Some(nonce),
4463        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4464        edited_at: Some(updated_at.unix_timestamp() as u64),
4465    };
4466
4467    response.send(proto::Ack {})?;
4468
4469    let pool = &*session.connection_pool().await;
4470    broadcast(
4471        Some(session.connection_id),
4472        participant_connection_ids,
4473        |connection| {
4474            session.peer.send(
4475                connection,
4476                proto::ChannelMessageUpdate {
4477                    channel_id: channel_id.to_proto(),
4478                    message: Some(message.clone()),
4479                },
4480            )?;
4481
4482            for notification_id in &deleted_mention_notification_ids {
4483                session.peer.send(
4484                    connection,
4485                    proto::DeleteNotification {
4486                        notification_id: (*notification_id).to_proto(),
4487                    },
4488                )?;
4489            }
4490
4491            for notification in &updated_mention_notifications {
4492                session.peer.send(
4493                    connection,
4494                    proto::UpdateNotification {
4495                        notification: Some(notification.clone()),
4496                    },
4497                )?;
4498            }
4499
4500            Ok(())
4501        },
4502    );
4503
4504    send_notifications(pool, &session.peer, notifications);
4505
4506    Ok(())
4507}
4508
4509/// Mark a channel message as read
4510async fn acknowledge_channel_message(
4511    request: proto::AckChannelMessage,
4512    session: UserSession,
4513) -> Result<()> {
4514    let channel_id = ChannelId::from_proto(request.channel_id);
4515    let message_id = MessageId::from_proto(request.message_id);
4516    let notifications = session
4517        .db()
4518        .await
4519        .observe_channel_message(channel_id, session.user_id(), message_id)
4520        .await?;
4521    send_notifications(
4522        &*session.connection_pool().await,
4523        &session.peer,
4524        notifications,
4525    );
4526    Ok(())
4527}
4528
4529/// Mark a buffer version as synced
4530async fn acknowledge_buffer_version(
4531    request: proto::AckBufferOperation,
4532    session: UserSession,
4533) -> Result<()> {
4534    let buffer_id = BufferId::from_proto(request.buffer_id);
4535    session
4536        .db()
4537        .await
4538        .observe_buffer_version(
4539            buffer_id,
4540            session.user_id(),
4541            request.epoch as i32,
4542            &request.version,
4543        )
4544        .await?;
4545    Ok(())
4546}
4547
4548async fn count_language_model_tokens(
4549    request: proto::CountLanguageModelTokens,
4550    response: Response<proto::CountLanguageModelTokens>,
4551    session: Session,
4552    config: &Config,
4553) -> Result<()> {
4554    let Some(session) = session.for_user() else {
4555        return Err(anyhow!("user not found"))?;
4556    };
4557    authorize_access_to_legacy_llm_endpoints(&session).await?;
4558
4559    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4560        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4561        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4562    };
4563
4564    session
4565        .app_state
4566        .rate_limiter
4567        .check(&*rate_limit, session.user_id())
4568        .await?;
4569
4570    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4571        Some(proto::LanguageModelProvider::Google) => {
4572            let api_key = config
4573                .google_ai_api_key
4574                .as_ref()
4575                .context("no Google AI API key configured on the server")?;
4576            google_ai::count_tokens(
4577                session.http_client.as_ref(),
4578                google_ai::API_URL,
4579                api_key,
4580                serde_json::from_str(&request.request)?,
4581            )
4582            .await?
4583        }
4584        _ => return Err(anyhow!("unsupported provider"))?,
4585    };
4586
4587    response.send(proto::CountLanguageModelTokensResponse {
4588        token_count: result.total_tokens as u32,
4589    })?;
4590
4591    Ok(())
4592}
4593
4594struct ZedProCountLanguageModelTokensRateLimit;
4595
4596impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4597    fn capacity(&self) -> usize {
4598        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4599            .ok()
4600            .and_then(|v| v.parse().ok())
4601            .unwrap_or(600) // Picked arbitrarily
4602    }
4603
4604    fn refill_duration(&self) -> chrono::Duration {
4605        chrono::Duration::hours(1)
4606    }
4607
4608    fn db_name(&self) -> &'static str {
4609        "zed-pro:count-language-model-tokens"
4610    }
4611}
4612
4613struct FreeCountLanguageModelTokensRateLimit;
4614
4615impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4616    fn capacity(&self) -> usize {
4617        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4618            .ok()
4619            .and_then(|v| v.parse().ok())
4620            .unwrap_or(600 / 10) // Picked arbitrarily
4621    }
4622
4623    fn refill_duration(&self) -> chrono::Duration {
4624        chrono::Duration::hours(1)
4625    }
4626
4627    fn db_name(&self) -> &'static str {
4628        "free:count-language-model-tokens"
4629    }
4630}
4631
4632struct ZedProComputeEmbeddingsRateLimit;
4633
4634impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4635    fn capacity(&self) -> usize {
4636        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4637            .ok()
4638            .and_then(|v| v.parse().ok())
4639            .unwrap_or(5000) // Picked arbitrarily
4640    }
4641
4642    fn refill_duration(&self) -> chrono::Duration {
4643        chrono::Duration::hours(1)
4644    }
4645
4646    fn db_name(&self) -> &'static str {
4647        "zed-pro:compute-embeddings"
4648    }
4649}
4650
4651struct FreeComputeEmbeddingsRateLimit;
4652
4653impl RateLimit for FreeComputeEmbeddingsRateLimit {
4654    fn capacity(&self) -> usize {
4655        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4656            .ok()
4657            .and_then(|v| v.parse().ok())
4658            .unwrap_or(5000 / 10) // Picked arbitrarily
4659    }
4660
4661    fn refill_duration(&self) -> chrono::Duration {
4662        chrono::Duration::hours(1)
4663    }
4664
4665    fn db_name(&self) -> &'static str {
4666        "free:compute-embeddings"
4667    }
4668}
4669
4670async fn compute_embeddings(
4671    request: proto::ComputeEmbeddings,
4672    response: Response<proto::ComputeEmbeddings>,
4673    session: UserSession,
4674    api_key: Option<Arc<str>>,
4675) -> Result<()> {
4676    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4677    authorize_access_to_legacy_llm_endpoints(&session).await?;
4678
4679    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4680        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4681        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4682    };
4683
4684    session
4685        .app_state
4686        .rate_limiter
4687        .check(&*rate_limit, session.user_id())
4688        .await?;
4689
4690    let embeddings = match request.model.as_str() {
4691        "openai/text-embedding-3-small" => {
4692            open_ai::embed(
4693                session.http_client.as_ref(),
4694                OPEN_AI_API_URL,
4695                &api_key,
4696                OpenAiEmbeddingModel::TextEmbedding3Small,
4697                request.texts.iter().map(|text| text.as_str()),
4698            )
4699            .await?
4700        }
4701        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4702    };
4703
4704    let embeddings = request
4705        .texts
4706        .iter()
4707        .map(|text| {
4708            let mut hasher = sha2::Sha256::new();
4709            hasher.update(text.as_bytes());
4710            let result = hasher.finalize();
4711            result.to_vec()
4712        })
4713        .zip(
4714            embeddings
4715                .data
4716                .into_iter()
4717                .map(|embedding| embedding.embedding),
4718        )
4719        .collect::<HashMap<_, _>>();
4720
4721    let db = session.db().await;
4722    db.save_embeddings(&request.model, &embeddings)
4723        .await
4724        .context("failed to save embeddings")
4725        .trace_err();
4726
4727    response.send(proto::ComputeEmbeddingsResponse {
4728        embeddings: embeddings
4729            .into_iter()
4730            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4731            .collect(),
4732    })?;
4733    Ok(())
4734}
4735
4736async fn get_cached_embeddings(
4737    request: proto::GetCachedEmbeddings,
4738    response: Response<proto::GetCachedEmbeddings>,
4739    session: UserSession,
4740) -> Result<()> {
4741    authorize_access_to_legacy_llm_endpoints(&session).await?;
4742
4743    let db = session.db().await;
4744    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4745
4746    response.send(proto::GetCachedEmbeddingsResponse {
4747        embeddings: embeddings
4748            .into_iter()
4749            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4750            .collect(),
4751    })?;
4752    Ok(())
4753}
4754
4755/// This is leftover from before the LLM service.
4756///
4757/// The endpoints protected by this check will be moved there eventually.
4758async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4759    if session.is_staff() {
4760        Ok(())
4761    } else {
4762        Err(anyhow!("permission denied"))?
4763    }
4764}
4765
4766/// Get a Supermaven API key for the user
4767async fn get_supermaven_api_key(
4768    _request: proto::GetSupermavenApiKey,
4769    response: Response<proto::GetSupermavenApiKey>,
4770    session: UserSession,
4771) -> Result<()> {
4772    let user_id: String = session.user_id().to_string();
4773    if !session.is_staff() {
4774        return Err(anyhow!("supermaven not enabled for this account"))?;
4775    }
4776
4777    let email = session
4778        .email()
4779        .ok_or_else(|| anyhow!("user must have an email"))?;
4780
4781    let supermaven_admin_api = session
4782        .supermaven_client
4783        .as_ref()
4784        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4785
4786    let result = supermaven_admin_api
4787        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4788        .await?;
4789
4790    response.send(proto::GetSupermavenApiKeyResponse {
4791        api_key: result.api_key,
4792    })?;
4793
4794    Ok(())
4795}
4796
4797/// Start receiving chat updates for a channel
4798async fn join_channel_chat(
4799    request: proto::JoinChannelChat,
4800    response: Response<proto::JoinChannelChat>,
4801    session: UserSession,
4802) -> Result<()> {
4803    let channel_id = ChannelId::from_proto(request.channel_id);
4804
4805    let db = session.db().await;
4806    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4807        .await?;
4808    let messages = db
4809        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4810        .await?;
4811    response.send(proto::JoinChannelChatResponse {
4812        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4813        messages,
4814    })?;
4815    Ok(())
4816}
4817
4818/// Stop receiving chat updates for a channel
4819async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4820    let channel_id = ChannelId::from_proto(request.channel_id);
4821    session
4822        .db()
4823        .await
4824        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4825        .await?;
4826    Ok(())
4827}
4828
4829/// Retrieve the chat history for a channel
4830async fn get_channel_messages(
4831    request: proto::GetChannelMessages,
4832    response: Response<proto::GetChannelMessages>,
4833    session: UserSession,
4834) -> Result<()> {
4835    let channel_id = ChannelId::from_proto(request.channel_id);
4836    let messages = session
4837        .db()
4838        .await
4839        .get_channel_messages(
4840            channel_id,
4841            session.user_id(),
4842            MESSAGE_COUNT_PER_PAGE,
4843            Some(MessageId::from_proto(request.before_message_id)),
4844        )
4845        .await?;
4846    response.send(proto::GetChannelMessagesResponse {
4847        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4848        messages,
4849    })?;
4850    Ok(())
4851}
4852
4853/// Retrieve specific chat messages
4854async fn get_channel_messages_by_id(
4855    request: proto::GetChannelMessagesById,
4856    response: Response<proto::GetChannelMessagesById>,
4857    session: UserSession,
4858) -> Result<()> {
4859    let message_ids = request
4860        .message_ids
4861        .iter()
4862        .map(|id| MessageId::from_proto(*id))
4863        .collect::<Vec<_>>();
4864    let messages = session
4865        .db()
4866        .await
4867        .get_channel_messages_by_id(session.user_id(), &message_ids)
4868        .await?;
4869    response.send(proto::GetChannelMessagesResponse {
4870        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4871        messages,
4872    })?;
4873    Ok(())
4874}
4875
4876/// Retrieve the current users notifications
4877async fn get_notifications(
4878    request: proto::GetNotifications,
4879    response: Response<proto::GetNotifications>,
4880    session: UserSession,
4881) -> Result<()> {
4882    let notifications = session
4883        .db()
4884        .await
4885        .get_notifications(
4886            session.user_id(),
4887            NOTIFICATION_COUNT_PER_PAGE,
4888            request
4889                .before_id
4890                .map(|id| db::NotificationId::from_proto(id)),
4891        )
4892        .await?;
4893    response.send(proto::GetNotificationsResponse {
4894        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4895        notifications,
4896    })?;
4897    Ok(())
4898}
4899
4900/// Mark notifications as read
4901async fn mark_notification_as_read(
4902    request: proto::MarkNotificationRead,
4903    response: Response<proto::MarkNotificationRead>,
4904    session: UserSession,
4905) -> Result<()> {
4906    let database = &session.db().await;
4907    let notifications = database
4908        .mark_notification_as_read_by_id(
4909            session.user_id(),
4910            NotificationId::from_proto(request.notification_id),
4911        )
4912        .await?;
4913    send_notifications(
4914        &*session.connection_pool().await,
4915        &session.peer,
4916        notifications,
4917    );
4918    response.send(proto::Ack {})?;
4919    Ok(())
4920}
4921
4922/// Get the current users information
4923async fn get_private_user_info(
4924    _request: proto::GetPrivateUserInfo,
4925    response: Response<proto::GetPrivateUserInfo>,
4926    session: UserSession,
4927) -> Result<()> {
4928    let db = session.db().await;
4929
4930    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4931    let user = db
4932        .get_user_by_id(session.user_id())
4933        .await?
4934        .ok_or_else(|| anyhow!("user not found"))?;
4935    let flags = db.get_user_flags(session.user_id()).await?;
4936
4937    response.send(proto::GetPrivateUserInfoResponse {
4938        metrics_id,
4939        staff: user.admin,
4940        flags,
4941        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4942    })?;
4943    Ok(())
4944}
4945
4946/// Accept the terms of service (tos) on behalf of the current user
4947async fn accept_terms_of_service(
4948    _request: proto::AcceptTermsOfService,
4949    response: Response<proto::AcceptTermsOfService>,
4950    session: UserSession,
4951) -> Result<()> {
4952    let db = session.db().await;
4953
4954    let accepted_tos_at = Utc::now();
4955    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4956        .await?;
4957
4958    response.send(proto::AcceptTermsOfServiceResponse {
4959        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4960    })?;
4961    Ok(())
4962}
4963
4964/// The minimum account age an account must have in order to use the LLM service.
4965const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4966
4967async fn get_llm_api_token(
4968    _request: proto::GetLlmToken,
4969    response: Response<proto::GetLlmToken>,
4970    session: UserSession,
4971) -> Result<()> {
4972    let db = session.db().await;
4973
4974    let flags = db.get_user_flags(session.user_id()).await?;
4975    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4976    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4977
4978    if !session.is_staff() && !has_language_models_feature_flag {
4979        Err(anyhow!("permission denied"))?
4980    }
4981
4982    let user_id = session.user_id();
4983    let user = db
4984        .get_user_by_id(user_id)
4985        .await?
4986        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4987
4988    if user.accepted_tos_at.is_none() {
4989        Err(anyhow!("terms of service not accepted"))?
4990    }
4991
4992    let mut account_created_at = user.created_at;
4993    if let Some(github_created_at) = user.github_user_created_at {
4994        account_created_at = account_created_at.min(github_created_at);
4995    }
4996    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4997        Err(anyhow!("account too young"))?
4998    }
4999    let token = LlmTokenClaims::create(
5000        user.id,
5001        user.github_login.clone(),
5002        session.is_staff(),
5003        has_llm_closed_beta_feature_flag,
5004        session.current_plan(db).await?,
5005        &session.app_state.config,
5006    )?;
5007    response.send(proto::GetLlmTokenResponse { token })?;
5008    Ok(())
5009}
5010
5011fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5012    let message = match message {
5013        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5014        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5015        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5016        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5017        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5018            code: frame.code.into(),
5019            reason: frame.reason,
5020        })),
5021        // We should never receive a frame while reading the message, according
5022        // to the `tungstenite` maintainers:
5023        //
5024        // > It cannot occur when you read messages from the WebSocket, but it
5025        // > can be used when you want to send the raw frames (e.g. you want to
5026        // > send the frames to the WebSocket without composing the full message first).
5027        // >
5028        // > — https://github.com/snapview/tungstenite-rs/issues/268
5029        TungsteniteMessage::Frame(_) => {
5030            bail!("received an unexpected frame while reading the message")
5031        }
5032    };
5033
5034    Ok(message)
5035}
5036
5037fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5038    match message {
5039        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5040        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5041        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5042        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5043        AxumMessage::Close(frame) => {
5044            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5045                code: frame.code.into(),
5046                reason: frame.reason,
5047            }))
5048        }
5049    }
5050}
5051
5052fn notify_membership_updated(
5053    connection_pool: &mut ConnectionPool,
5054    result: MembershipUpdated,
5055    user_id: UserId,
5056    peer: &Peer,
5057) {
5058    for membership in &result.new_channels.channel_memberships {
5059        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5060    }
5061    for channel_id in &result.removed_channels {
5062        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5063    }
5064
5065    let user_channels_update = proto::UpdateUserChannels {
5066        channel_memberships: result
5067            .new_channels
5068            .channel_memberships
5069            .iter()
5070            .map(|cm| proto::ChannelMembership {
5071                channel_id: cm.channel_id.to_proto(),
5072                role: cm.role.into(),
5073            })
5074            .collect(),
5075        ..Default::default()
5076    };
5077
5078    let mut update = build_channels_update(result.new_channels);
5079    update.delete_channels = result
5080        .removed_channels
5081        .into_iter()
5082        .map(|id| id.to_proto())
5083        .collect();
5084    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5085
5086    for connection_id in connection_pool.user_connection_ids(user_id) {
5087        peer.send(connection_id, user_channels_update.clone())
5088            .trace_err();
5089        peer.send(connection_id, update.clone()).trace_err();
5090    }
5091}
5092
5093fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5094    proto::UpdateUserChannels {
5095        channel_memberships: channels
5096            .channel_memberships
5097            .iter()
5098            .map(|m| proto::ChannelMembership {
5099                channel_id: m.channel_id.to_proto(),
5100                role: m.role.into(),
5101            })
5102            .collect(),
5103        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5104        observed_channel_message_id: channels.observed_channel_messages.clone(),
5105    }
5106}
5107
5108fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5109    let mut update = proto::UpdateChannels::default();
5110
5111    for channel in channels.channels {
5112        update.channels.push(channel.to_proto());
5113    }
5114
5115    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5116    update.latest_channel_message_ids = channels.latest_channel_messages;
5117
5118    for (channel_id, participants) in channels.channel_participants {
5119        update
5120            .channel_participants
5121            .push(proto::ChannelParticipants {
5122                channel_id: channel_id.to_proto(),
5123                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5124            });
5125    }
5126
5127    for channel in channels.invited_channels {
5128        update.channel_invitations.push(channel.to_proto());
5129    }
5130
5131    update.hosted_projects = channels.hosted_projects;
5132    update
5133}
5134
5135fn build_initial_contacts_update(
5136    contacts: Vec<db::Contact>,
5137    pool: &ConnectionPool,
5138) -> proto::UpdateContacts {
5139    let mut update = proto::UpdateContacts::default();
5140
5141    for contact in contacts {
5142        match contact {
5143            db::Contact::Accepted { user_id, busy } => {
5144                update.contacts.push(contact_for_user(user_id, busy, &pool));
5145            }
5146            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5147            db::Contact::Incoming { user_id } => {
5148                update
5149                    .incoming_requests
5150                    .push(proto::IncomingContactRequest {
5151                        requester_id: user_id.to_proto(),
5152                    })
5153            }
5154        }
5155    }
5156
5157    update
5158}
5159
5160fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5161    proto::Contact {
5162        user_id: user_id.to_proto(),
5163        online: pool.is_user_online(user_id),
5164        busy,
5165    }
5166}
5167
5168fn room_updated(room: &proto::Room, peer: &Peer) {
5169    broadcast(
5170        None,
5171        room.participants
5172            .iter()
5173            .filter_map(|participant| Some(participant.peer_id?.into())),
5174        |peer_id| {
5175            peer.send(
5176                peer_id,
5177                proto::RoomUpdated {
5178                    room: Some(room.clone()),
5179                },
5180            )
5181        },
5182    );
5183}
5184
5185fn channel_updated(
5186    channel: &db::channel::Model,
5187    room: &proto::Room,
5188    peer: &Peer,
5189    pool: &ConnectionPool,
5190) {
5191    let participants = room
5192        .participants
5193        .iter()
5194        .map(|p| p.user_id)
5195        .collect::<Vec<_>>();
5196
5197    broadcast(
5198        None,
5199        pool.channel_connection_ids(channel.root_id())
5200            .filter_map(|(channel_id, role)| {
5201                role.can_see_channel(channel.visibility).then(|| channel_id)
5202            }),
5203        |peer_id| {
5204            peer.send(
5205                peer_id,
5206                proto::UpdateChannels {
5207                    channel_participants: vec![proto::ChannelParticipants {
5208                        channel_id: channel.id.to_proto(),
5209                        participant_user_ids: participants.clone(),
5210                    }],
5211                    ..Default::default()
5212                },
5213            )
5214        },
5215    );
5216}
5217
5218async fn send_dev_server_projects_update(
5219    user_id: UserId,
5220    mut status: proto::DevServerProjectsUpdate,
5221    session: &Session,
5222) {
5223    let pool = session.connection_pool().await;
5224    for dev_server in &mut status.dev_servers {
5225        dev_server.status =
5226            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5227    }
5228    let connections = pool.user_connection_ids(user_id);
5229    for connection_id in connections {
5230        session.peer.send(connection_id, status.clone()).trace_err();
5231    }
5232}
5233
5234async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5235    let db = session.db().await;
5236
5237    let contacts = db.get_contacts(user_id).await?;
5238    let busy = db.is_user_busy(user_id).await?;
5239
5240    let pool = session.connection_pool().await;
5241    let updated_contact = contact_for_user(user_id, busy, &pool);
5242    for contact in contacts {
5243        if let db::Contact::Accepted {
5244            user_id: contact_user_id,
5245            ..
5246        } = contact
5247        {
5248            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5249                session
5250                    .peer
5251                    .send(
5252                        contact_conn_id,
5253                        proto::UpdateContacts {
5254                            contacts: vec![updated_contact.clone()],
5255                            remove_contacts: Default::default(),
5256                            incoming_requests: Default::default(),
5257                            remove_incoming_requests: Default::default(),
5258                            outgoing_requests: Default::default(),
5259                            remove_outgoing_requests: Default::default(),
5260                        },
5261                    )
5262                    .trace_err();
5263            }
5264        }
5265    }
5266    Ok(())
5267}
5268
5269async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5270    log::info!("lost dev server connection, unsharing projects");
5271    let project_ids = session
5272        .db()
5273        .await
5274        .get_stale_dev_server_projects(session.connection_id)
5275        .await?;
5276
5277    for project_id in project_ids {
5278        // not unshare re-checks the connection ids match, so we get away with no transaction
5279        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5280    }
5281
5282    let user_id = session.dev_server().user_id;
5283    let update = session
5284        .db()
5285        .await
5286        .dev_server_projects_update(user_id)
5287        .await?;
5288
5289    send_dev_server_projects_update(user_id, update, session).await;
5290
5291    Ok(())
5292}
5293
/// Removes `connection_id` from the room it is in (as reported by
/// `leave_room`) and notifies everyone affected.
///
/// If `leave_room` returns `None`, this returns `Ok(())` and does nothing else.
/// Otherwise it:
/// - notifies collaborators of each project that was left (`project_left`),
/// - broadcasts the updated room to its participants (`room_updated`),
/// - broadcasts updated participants for the room's channel, if there is one,
/// - sends `CallCanceled` to every connection of each user whose call was canceled,
/// - refreshes contact state for this user and those users (`update_user_contacts`),
/// - if a LiveKit client is configured, removes this user from the LiveKit room.
///   It also deletes the LiveKit room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. Message-send
/// and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared uninitialized and assigned inside the `if let` below. The `else`
    // branch returns, so all later code sees them initialized.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // The `session.db().await` temporary in this scrutinee lives until the end of
    // the whole `if let` statement. NOTE(review): if it is a lock guard, it stays
    // held while `project_left` and `room_updated` run. Keep that in mind before
    // restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out. As a result, `room` (and the
        // `RoomUpdated` broadcast below) carries the default value for
        // `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before the `update_user_contacts` awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): if any `update_user_contacts` call fails, `?` returns early and
    // the LiveKit cleanup below is skipped. Confirm that this is acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5367
5368async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5369    let left_channel_buffers = session
5370        .db()
5371        .await
5372        .leave_channel_buffers(session.connection_id)
5373        .await?;
5374
5375    for left_buffer in left_channel_buffers {
5376        channel_buffer_updated(
5377            session.connection_id,
5378            left_buffer.connections,
5379            &proto::UpdateChannelBufferCollaborators {
5380                channel_id: left_buffer.channel_id.to_proto(),
5381                collaborators: left_buffer.collaborators,
5382            },
5383            &session.peer,
5384        );
5385    }
5386
5387    Ok(())
5388}
5389
5390fn project_left(project: &db::LeftProject, session: &UserSession) {
5391    for connection_id in &project.connection_ids {
5392        if project.should_unshare {
5393            session
5394                .peer
5395                .send(
5396                    *connection_id,
5397                    proto::UnshareProject {
5398                        project_id: project.id.to_proto(),
5399                    },
5400                )
5401                .trace_err();
5402        } else {
5403            session
5404                .peer
5405                .send(
5406                    *connection_id,
5407                    proto::RemoveProjectCollaborator {
5408                        project_id: project.id.to_proto(),
5409                        peer_id: Some(session.connection_id.into()),
5410                    },
5411                )
5412                .trace_err();
5413        }
5414    }
5415}
5416
5417pub trait ResultExt {
5418    type Ok;
5419
5420    fn trace_err(self) -> Option<Self::Ok>;
5421}
5422
5423impl<T, E> ResultExt for Result<T, E>
5424where
5425    E: std::fmt::Debug,
5426{
5427    type Ok = T;
5428
5429    #[track_caller]
5430    fn trace_err(self) -> Option<T> {
5431        match self {
5432            Ok(value) => Some(value),
5433            Err(error) => {
5434                tracing::error!("{:?}", error);
5435                None
5436            }
5437        }
5438    }
5439}