rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use isahc_http_client::IsahcHttpClient;
  40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
/// A 30-second timeout.
///
/// NOTE(review): presumably the window a disconnected client has to reconnect; the constant is
/// not used in the visible part of this file — confirm against its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` sleeps before it retrieves and clears stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits.
// NOTE(review): meanings inferred from the names only; the code that reads them is outside this view.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
/// A type-erased handler: takes a boxed envelope of any message type together with the
/// `Session` it arrived on and returns a boxed future that resolves to `()`.
///
/// `Server` stores one of these per message `TypeId` in its `handlers` map.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
/// Handle used by a request handler to reply to a request of type `R`.
struct Response<R> {
    /// Peer through which the reply is sent.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Flag set to `true` by `send`, shared so that code holding another clone can observe it.
    responded: Arc<AtomicBool>,
}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`; user-level lookups in this file resolve to `user`.
    Impersonated { user: User, admin: User },
    /// A dev server, identified by its database model.
    DevServer(dev_server::Model),
}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
/// Per-connection state passed to every message handler.
///
/// Cloning is cheap for the shared parts: the database handle, peer, connection pool, app state
/// and clients are all behind `Arc`.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous mutex; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when one is not available for this session.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client shared by handlers.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor held by the session; the leading underscore marks it as not read directly.
    _executor: Executor,
}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 195        if self.is_staff() {
 196            return Ok(proto::Plan::ZedPro);
 197        }
 198
 199        let Some(user_id) = self.user_id() else {
 200            return Ok(proto::Plan::Free);
 201        };
 202
 203        if db.has_active_billing_subscription(user_id).await? {
 204            Ok(proto::Plan::ZedPro)
 205        } else {
 206            Ok(proto::Plan::Free)
 207        }
 208    }
 209
 210    fn dev_server_id(&self) -> Option<DevServerId> {
 211        match &self.principal {
 212            Principal::User(_) | Principal::Impersonated { .. } => None,
 213            Principal::DevServer(dev_server) => Some(dev_server.id),
 214        }
 215    }
 216
 217    fn principal_id(&self) -> PrincipalId {
 218        match &self.principal {
 219            Principal::User(user) => PrincipalId::UserId(user.id),
 220            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 221            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 222        }
 223    }
 224}
 225
 226impl Debug for Session {
 227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 228        let mut result = f.debug_struct("Session");
 229        match &self.principal {
 230            Principal::User(user) => {
 231                result.field("user", &user.github_login);
 232            }
 233            Principal::Impersonated { user, admin } => {
 234                result.field("user", &user.github_login);
 235                result.field("impersonator", &admin.github_login);
 236            }
 237            Principal::DevServer(dev_server) => {
 238                result.field("dev_server", &dev_server.id);
 239            }
 240        }
 241        result.field("connection_id", &self.connection_id).finish()
 242    }
 243}
 244
 245struct UserSession(Session);
 246
 247impl UserSession {
 248    pub fn new(s: Session) -> Option<Self> {
 249        s.user_id().map(|_| UserSession(s))
 250    }
 251    pub fn user_id(&self) -> UserId {
 252        self.0.user_id().unwrap()
 253    }
 254
 255    pub fn email(&self) -> Option<String> {
 256        match &self.0.principal {
 257            Principal::User(user) => user.email_address.clone(),
 258            Principal::Impersonated { user, .. } => user.email_address.clone(),
 259            Principal::DevServer(..) => None,
 260        }
 261    }
 262}
 263
/// Lets a `UserSession` be used wherever a `&Session` is expected.
impl Deref for UserSession {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// Gives mutable access to the wrapped `Session`.
impl DerefMut for UserSession {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
 276
 277struct DevServerSession(Session);
 278
 279impl DevServerSession {
 280    pub fn new(s: Session) -> Option<Self> {
 281        s.dev_server_id().map(|_| DevServerSession(s))
 282    }
 283    pub fn dev_server_id(&self) -> DevServerId {
 284        self.0.dev_server_id().unwrap()
 285    }
 286
 287    fn dev_server(&self) -> &dev_server::Model {
 288        match &self.0.principal {
 289            Principal::DevServer(dev_server) => dev_server,
 290            _ => unreachable!(),
 291        }
 292    }
 293}
 294
/// Lets a `DevServerSession` be used wherever a `&Session` is expected.
impl Deref for DevServerSession {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// Gives mutable access to the wrapped `Session`.
impl DerefMut for DevServerSession {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
 307
 308fn user_handler<M: RequestMessage, Fut>(
 309    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 311where
 312    Fut: Send + Future<Output = Result<()>>,
 313{
 314    let handler = Arc::new(handler);
 315    move |message, response, session| {
 316        let handler = handler.clone();
 317        Box::pin(async move {
 318            if let Some(user_session) = session.for_user() {
 319                Ok(handler(message, response, user_session).await?)
 320            } else {
 321                Err(Error::Internal(anyhow!(
 322                    "must be a user to call {}",
 323                    M::NAME
 324                )))
 325            }
 326        })
 327    }
 328}
 329
 330fn dev_server_handler<M: RequestMessage, Fut>(
 331    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 333where
 334    Fut: Send + Future<Output = Result<()>>,
 335{
 336    let handler = Arc::new(handler);
 337    move |message, response, session| {
 338        let handler = handler.clone();
 339        Box::pin(async move {
 340            if let Some(dev_server_session) = session.for_dev_server() {
 341                Ok(handler(message, response, dev_server_session).await?)
 342            } else {
 343                Err(Error::Internal(anyhow!(
 344                    "must be a dev server to call {}",
 345                    M::NAME
 346                )))
 347            }
 348        })
 349    }
 350}
 351
 352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 353    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 355where
 356    InnertRetFut: Send + Future<Output = Result<()>>,
 357{
 358    let handler = Arc::new(handler);
 359    move |message, session| {
 360        let handler = handler.clone();
 361        Box::pin(async move {
 362            if let Some(user_session) = session.for_user() {
 363                Ok(handler(message, user_session).await?)
 364            } else {
 365                Err(Error::Internal(anyhow!(
 366                    "must be a user to call {}",
 367                    M::NAME
 368                )))
 369            }
 370        })
 371    }
 372}
 373
/// Newtype around a shared `Database`; `Session` keeps it behind an async mutex.
struct DbHandle(Arc<Database>);

/// Exposes the wrapped `Database`'s methods directly on the handle.
impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
 383
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id, behind a mutex so it can be read (and presumably replaced) via `&self`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to exchange messages with connections.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with sessions and background tasks.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel sender initialised to `false`.
    /// NOTE(review): presumably flipped to signal teardown — the sending side is outside this view.
    teardown: watch::Sender<bool>,
}
 392
/// A held lock on the `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Zero-sized marker: `Rc` is not `Send`, so this field makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 397
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// Borrowed peer, serialized with its own `Serialize` implementation.
    peer: &'a Peer,
    /// Pool guard; serialized by dereferencing to the guarded value via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 404
 405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 406where
 407    S: Serializer,
 408    T: Deref<Target = U>,
 409    U: Serialize,
 410{
 411    Serialize::serialize(value.deref(), serializer)
 412}
 413
 414impl Server {
    /// Creates a server with the given id and application state, registers every RPC handler
    /// on it, and returns it wrapped in an `Arc`.
    ///
    /// The connection pool and handler table start out empty (`Default`), and the teardown
    /// watch channel starts at `false`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates silently if the id does not fit — confirm the
            // width of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half of the watch channel is kept.
            teardown: watch::channel(false).0,
        };

        // Handlers wrapped in `user_handler` / `user_message_handler` reject non-user sessions;
        // those wrapped in `dev_server_handler` reject non-dev-server sessions. Unwrapped
        // handlers receive the plain `Session`.
        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded on behalf of the project owner.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            // Project requests forwarded through the read-only path.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            // Project requests forwarded through the mutating path.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer updates and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following, account info and acknowledgements.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Token counting needs the app config, so the closure carries its own `app_state`
            // clone and re-clones it into each returned future.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Computing embeddings is given a clone of the configured OpenAI API key on each call.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 654
 655    pub async fn start(&self) -> Result<()> {
 656        let server_id = *self.id.lock();
 657        let app_state = self.app_state.clone();
 658        let peer = self.peer.clone();
 659        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 660        let pool = self.connection_pool.clone();
 661        let live_kit_client = self.app_state.live_kit_client.clone();
 662
 663        let span = info_span!("start server");
 664        self.app_state.executor.spawn_detached(
 665            async move {
 666                tracing::info!("waiting for cleanup timeout");
 667                timeout.await;
 668                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 669                if let Some((room_ids, channel_ids)) = app_state
 670                    .db
 671                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 672                    .await
 673                    .trace_err()
 674                {
 675                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 676                    tracing::info!(
 677                        stale_channel_buffer_count = channel_ids.len(),
 678                        "retrieved stale channel buffers"
 679                    );
 680
 681                    for channel_id in channel_ids {
 682                        if let Some(refreshed_channel_buffer) = app_state
 683                            .db
 684                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 685                            .await
 686                            .trace_err()
 687                        {
 688                            for connection_id in refreshed_channel_buffer.connection_ids {
 689                                peer.send(
 690                                    connection_id,
 691                                    proto::UpdateChannelBufferCollaborators {
 692                                        channel_id: channel_id.to_proto(),
 693                                        collaborators: refreshed_channel_buffer
 694                                            .collaborators
 695                                            .clone(),
 696                                    },
 697                                )
 698                                .trace_err();
 699                            }
 700                        }
 701                    }
 702
 703                    for room_id in room_ids {
 704                        let mut contacts_to_update = HashSet::default();
 705                        let mut canceled_calls_to_user_ids = Vec::new();
 706                        let mut live_kit_room = String::new();
 707                        let mut delete_live_kit_room = false;
 708
 709                        if let Some(mut refreshed_room) = app_state
 710                            .db
 711                            .clear_stale_room_participants(room_id, server_id)
 712                            .await
 713                            .trace_err()
 714                        {
 715                            tracing::info!(
 716                                room_id = room_id.0,
 717                                new_participant_count = refreshed_room.room.participants.len(),
 718                                "refreshed room"
 719                            );
 720                            room_updated(&refreshed_room.room, &peer);
 721                            if let Some(channel) = refreshed_room.channel.as_ref() {
 722                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 723                            }
 724                            contacts_to_update
 725                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 726                            contacts_to_update
 727                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 728                            canceled_calls_to_user_ids =
 729                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 730                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 731                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 732                        }
 733
 734                        {
 735                            let pool = pool.lock();
 736                            for canceled_user_id in canceled_calls_to_user_ids {
 737                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 738                                    peer.send(
 739                                        connection_id,
 740                                        proto::CallCanceled {
 741                                            room_id: room_id.to_proto(),
 742                                        },
 743                                    )
 744                                    .trace_err();
 745                                }
 746                            }
 747                        }
 748
 749                        for user_id in contacts_to_update {
 750                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 751                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 752                            if let Some((busy, contacts)) = busy.zip(contacts) {
 753                                let pool = pool.lock();
 754                                let updated_contact = contact_for_user(user_id, busy, &pool);
 755                                for contact in contacts {
 756                                    if let db::Contact::Accepted {
 757                                        user_id: contact_user_id,
 758                                        ..
 759                                    } = contact
 760                                    {
 761                                        for contact_conn_id in
 762                                            pool.user_connection_ids(contact_user_id)
 763                                        {
 764                                            peer.send(
 765                                                contact_conn_id,
 766                                                proto::UpdateContacts {
 767                                                    contacts: vec![updated_contact.clone()],
 768                                                    remove_contacts: Default::default(),
 769                                                    incoming_requests: Default::default(),
 770                                                    remove_incoming_requests: Default::default(),
 771                                                    outgoing_requests: Default::default(),
 772                                                    remove_outgoing_requests: Default::default(),
 773                                                },
 774                                            )
 775                                            .trace_err();
 776                                        }
 777                                    }
 778                                }
 779                            }
 780                        }
 781
 782                        if let Some(live_kit) = live_kit_client.as_ref() {
 783                            if delete_live_kit_room {
 784                                live_kit.delete_room(live_kit_room).await.trace_err();
 785                            }
 786                        }
 787                    }
 788                }
 789
 790                app_state
 791                    .db
 792                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 793                    .await
 794                    .trace_err();
 795            }
 796            .instrument(span),
 797        );
 798        Ok(())
 799    }
 800
 801    pub fn teardown(&self) {
 802        self.peer.teardown();
 803        self.connection_pool.lock().reset();
 804        let _ = self.teardown.send(true);
 805    }
 806
 807    #[cfg(test)]
 808    pub fn reset(&self, id: ServerId) {
 809        self.teardown();
 810        *self.id.lock() = id;
 811        self.peer.reset(id.0 as u32);
 812        let _ = self.teardown.send(false);
 813    }
 814
 815    #[cfg(test)]
 816    pub fn id(&self) -> ServerId {
 817        *self.id.lock()
 818    }
 819
 820    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 821    where
 822        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 823        Fut: 'static + Send + Future<Output = Result<()>>,
 824        M: EnvelopedMessage,
 825    {
 826        let prev_handler = self.handlers.insert(
 827            TypeId::of::<M>(),
 828            Box::new(move |envelope, session| {
 829                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 830                let received_at = envelope.received_at;
 831                tracing::info!("message received");
 832                let start_time = Instant::now();
 833                let future = (handler)(*envelope, session);
 834                async move {
 835                    let result = future.await;
 836                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 837                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 838                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 839                    let payload_type = M::NAME;
 840
 841                    match result {
 842                        Err(error) => {
 843                            tracing::error!(
 844                                ?error,
 845                                total_duration_ms,
 846                                processing_duration_ms,
 847                                queue_duration_ms,
 848                                payload_type,
 849                                "error handling message"
 850                            )
 851                        }
 852                        Ok(()) => tracing::info!(
 853                            total_duration_ms,
 854                            processing_duration_ms,
 855                            queue_duration_ms,
 856                            "finished handling message"
 857                        ),
 858                    }
 859                }
 860                .boxed()
 861            }),
 862        );
 863        if prev_handler.is_some() {
 864            panic!("registered a handler for the same message twice");
 865        }
 866        self
 867    }
 868
 869    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 870    where
 871        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 872        Fut: 'static + Send + Future<Output = Result<()>>,
 873        M: EnvelopedMessage,
 874    {
 875        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 876        self
 877    }
 878
 879    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 880    where
 881        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 882        Fut: Send + Future<Output = Result<()>>,
 883        M: RequestMessage,
 884    {
 885        let handler = Arc::new(handler);
 886        self.add_handler(move |envelope, session| {
 887            let receipt = envelope.receipt();
 888            let handler = handler.clone();
 889            async move {
 890                let peer = session.peer.clone();
 891                let responded = Arc::new(AtomicBool::default());
 892                let response = Response {
 893                    peer: peer.clone(),
 894                    responded: responded.clone(),
 895                    receipt,
 896                };
 897                match (handler)(envelope.payload, response, session).await {
 898                    Ok(()) => {
 899                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 900                            Ok(())
 901                        } else {
 902                            Err(anyhow!("handler did not send a response"))?
 903                        }
 904                    }
 905                    Err(error) => {
 906                        let proto_err = match &error {
 907                            Error::Internal(err) => err.to_proto(),
 908                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 909                        };
 910                        peer.respond_with_error(receipt, proto_err)?;
 911                        Err(error)
 912                    }
 913                }
 914            }
 915        })
 916    }
 917
 918    #[allow(clippy::too_many_arguments)]
 919    pub fn handle_connection(
 920        self: &Arc<Self>,
 921        connection: Connection,
 922        address: String,
 923        principal: Principal,
 924        zed_version: ZedVersion,
 925        geoip_country_code: Option<String>,
 926        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 927        executor: Executor,
 928    ) -> impl Future<Output = ()> {
 929        let this = self.clone();
 930        let span = info_span!("handle connection", %address,
 931            connection_id=field::Empty,
 932            user_id=field::Empty,
 933            login=field::Empty,
 934            impersonator=field::Empty,
 935            dev_server_id=field::Empty,
 936            geoip_country_code=field::Empty
 937        );
 938        principal.update_span(&span);
 939        if let Some(country_code) = geoip_country_code.as_ref() {
 940            span.record("geoip_country_code", country_code);
 941        }
 942
 943        let mut teardown = self.teardown.subscribe();
 944        async move {
 945            if *teardown.borrow() {
 946                tracing::error!("server is tearing down");
 947                return
 948            }
 949            let (connection_id, handle_io, mut incoming_rx) = this
 950                .peer
 951                .add_connection(connection, {
 952                    let executor = executor.clone();
 953                    move |duration| executor.sleep(duration)
 954                });
 955            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 956
 957            tracing::info!("connection opened");
 958
 959
 960            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 961            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 962                Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
 963                Err(error) => {
 964                    tracing::error!(?error, "failed to create HTTP client");
 965                    return;
 966                }
 967            };
 968
 969            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 970                    supermaven_admin_api_key.to_string(),
 971                    http_client.clone(),
 972                )));
 973
 974            let session = Session {
 975                principal: principal.clone(),
 976                connection_id,
 977                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 978                peer: this.peer.clone(),
 979                connection_pool: this.connection_pool.clone(),
 980                app_state: this.app_state.clone(),
 981                http_client,
 982                geoip_country_code,
 983                _executor: executor.clone(),
 984                supermaven_client,
 985            };
 986
 987            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 988                tracing::error!(?error, "failed to send initial client update");
 989                return;
 990            }
 991
 992            let handle_io = handle_io.fuse();
 993            futures::pin_mut!(handle_io);
 994
 995            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 996            // This prevents deadlocks when e.g., client A performs a request to client B and
 997            // client B performs a request to client A. If both clients stop processing further
 998            // messages until their respective request completes, they won't have a chance to
 999            // respond to the other client's request and cause a deadlock.
1000            //
1001            // This arrangement ensures we will attempt to process earlier messages first, but fall
1002            // back to processing messages arrived later in the spirit of making progress.
1003            let mut foreground_message_handlers = FuturesUnordered::new();
1004            let concurrent_handlers = Arc::new(Semaphore::new(256));
1005            loop {
1006                let next_message = async {
1007                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1008                    let message = incoming_rx.next().await;
1009                    (permit, message)
1010                }.fuse();
1011                futures::pin_mut!(next_message);
1012                futures::select_biased! {
1013                    _ = teardown.changed().fuse() => return,
1014                    result = handle_io => {
1015                        if let Err(error) = result {
1016                            tracing::error!(?error, "error handling I/O");
1017                        }
1018                        break;
1019                    }
1020                    _ = foreground_message_handlers.next() => {}
1021                    next_message = next_message => {
1022                        let (permit, message) = next_message;
1023                        if let Some(message) = message {
1024                            let type_name = message.payload_type_name();
1025                            // note: we copy all the fields from the parent span so we can query them in the logs.
1026                            // (https://github.com/tokio-rs/tracing/issues/2670).
1027                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1028                                user_id=field::Empty,
1029                                login=field::Empty,
1030                                impersonator=field::Empty,
1031                                dev_server_id=field::Empty
1032                            );
1033                            principal.update_span(&span);
1034                            let span_enter = span.enter();
1035                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1036                                let is_background = message.is_background();
1037                                let handle_message = (handler)(message, session.clone());
1038                                drop(span_enter);
1039
1040                                let handle_message = async move {
1041                                    handle_message.await;
1042                                    drop(permit);
1043                                }.instrument(span);
1044                                if is_background {
1045                                    executor.spawn_detached(handle_message);
1046                                } else {
1047                                    foreground_message_handlers.push(handle_message);
1048                                }
1049                            } else {
1050                                tracing::error!("no message handler");
1051                            }
1052                        } else {
1053                            tracing::info!("connection closed");
1054                            break;
1055                        }
1056                    }
1057                }
1058            }
1059
1060            drop(foreground_message_handlers);
1061            tracing::info!("signing out");
1062            if let Err(error) = connection_lost(session, teardown, executor).await {
1063                tracing::error!(?error, "error signing out");
1064            }
1065
1066        }.instrument(span)
1067    }
1068
1069    async fn send_initial_client_update(
1070        &self,
1071        connection_id: ConnectionId,
1072        principal: &Principal,
1073        zed_version: ZedVersion,
1074        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1075        session: &Session,
1076    ) -> Result<()> {
1077        self.peer.send(
1078            connection_id,
1079            proto::Hello {
1080                peer_id: Some(connection_id.into()),
1081            },
1082        )?;
1083        tracing::info!("sent hello message");
1084        if let Some(send_connection_id) = send_connection_id.take() {
1085            let _ = send_connection_id.send(connection_id);
1086        }
1087
1088        match principal {
1089            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1090                if !user.connected_once {
1091                    self.peer.send(connection_id, proto::ShowContacts {})?;
1092                    self.app_state
1093                        .db
1094                        .set_user_connected_once(user.id, true)
1095                        .await?;
1096                }
1097
1098                update_user_plan(user.id, session).await?;
1099
1100                let (contacts, dev_server_projects) = future::try_join(
1101                    self.app_state.db.get_contacts(user.id),
1102                    self.app_state.db.dev_server_projects_update(user.id),
1103                )
1104                .await?;
1105
1106                {
1107                    let mut pool = self.connection_pool.lock();
1108                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1109                    self.peer.send(
1110                        connection_id,
1111                        build_initial_contacts_update(contacts, &pool),
1112                    )?;
1113                }
1114
1115                if should_auto_subscribe_to_channels(zed_version) {
1116                    subscribe_user_to_channels(user.id, session).await?;
1117                }
1118
1119                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1120
1121                if let Some(incoming_call) =
1122                    self.app_state.db.incoming_call_for_user(user.id).await?
1123                {
1124                    self.peer.send(connection_id, incoming_call)?;
1125                }
1126
1127                update_user_contacts(user.id, session).await?;
1128            }
1129            Principal::DevServer(dev_server) => {
1130                {
1131                    let mut pool = self.connection_pool.lock();
1132                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1133                    {
1134                        self.peer.send(
1135                            stale_connection_id,
1136                            proto::ShutdownDevServer {
1137                                reason: Some(
1138                                    "another dev server connected with the same token".to_string(),
1139                                ),
1140                            },
1141                        )?;
1142                        pool.remove_connection(stale_connection_id)?;
1143                    };
1144                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1145                }
1146
1147                let projects = self
1148                    .app_state
1149                    .db
1150                    .get_projects_for_dev_server(dev_server.id)
1151                    .await?;
1152                self.peer
1153                    .send(connection_id, proto::DevServerInstructions { projects })?;
1154
1155                let status = self
1156                    .app_state
1157                    .db
1158                    .dev_server_projects_update(dev_server.user_id)
1159                    .await?;
1160                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1161            }
1162        }
1163
1164        Ok(())
1165    }
1166
1167    pub async fn invite_code_redeemed(
1168        self: &Arc<Self>,
1169        inviter_id: UserId,
1170        invitee_id: UserId,
1171    ) -> Result<()> {
1172        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173            if let Some(code) = &user.invite_code {
1174                let pool = self.connection_pool.lock();
1175                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176                for connection_id in pool.user_connection_ids(inviter_id) {
1177                    self.peer.send(
1178                        connection_id,
1179                        proto::UpdateContacts {
1180                            contacts: vec![invitee_contact.clone()],
1181                            ..Default::default()
1182                        },
1183                    )?;
1184                    self.peer.send(
1185                        connection_id,
1186                        proto::UpdateInviteInfo {
1187                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188                            count: user.invite_count as u32,
1189                        },
1190                    )?;
1191                }
1192            }
1193        }
1194        Ok(())
1195    }
1196
1197    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199            if let Some(invite_code) = &user.invite_code {
1200                let pool = self.connection_pool.lock();
1201                for connection_id in pool.user_connection_ids(user_id) {
1202                    self.peer.send(
1203                        connection_id,
1204                        proto::UpdateInviteInfo {
1205                            url: format!(
1206                                "{}{}",
1207                                self.app_state.config.invite_link_prefix, invite_code
1208                            ),
1209                            count: user.invite_count as u32,
1210                        },
1211                    )?;
1212                }
1213            }
1214        }
1215        Ok(())
1216    }
1217
1218    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219        ServerSnapshot {
1220            connection_pool: ConnectionPoolGuard {
1221                guard: self.connection_pool.lock(),
1222                _not_send: PhantomData,
1223            },
1224            peer: &self.peer,
1225        }
1226    }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230    type Target = ConnectionPool;
1231
1232    fn deref(&self) -> &Self::Target {
1233        &self.guard
1234    }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238    fn deref_mut(&mut self) -> &mut Self::Target {
1239        &mut self.guard
1240    }
1241}
1242
1243impl<'a> Drop for ConnectionPoolGuard<'a> {
1244    fn drop(&mut self) {
1245        #[cfg(test)]
1246        self.check_invariants();
1247    }
1248}
1249
1250fn broadcast<F>(
1251    sender_id: Option<ConnectionId>,
1252    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253    mut f: F,
1254) where
1255    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257    for receiver_id in receiver_ids {
1258        if Some(receiver_id) != sender_id {
1259            if let Err(error) = f(receiver_id) {
1260                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261            }
1262        }
1263    }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269    fn name() -> &'static HeaderName {
1270        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272    }
1273
1274    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275    where
1276        Self: Sized,
1277        I: Iterator<Item = &'i axum::http::HeaderValue>,
1278    {
1279        let version = values
1280            .next()
1281            .ok_or_else(axum::headers::Error::invalid)?
1282            .to_str()
1283            .map_err(|_| axum::headers::Error::invalid())?
1284            .parse()
1285            .map_err(|_| axum::headers::Error::invalid())?;
1286        Ok(Self(version))
1287    }
1288
1289    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290        values.extend([self.0.to_string().parse().unwrap()]);
1291    }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296    fn name() -> &'static HeaderName {
1297        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299    }
1300
1301    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302    where
1303        Self: Sized,
1304        I: Iterator<Item = &'i axum::http::HeaderValue>,
1305    {
1306        let version = values
1307            .next()
1308            .ok_or_else(axum::headers::Error::invalid)?
1309            .to_str()
1310            .map_err(|_| axum::headers::Error::invalid())?
1311            .parse()
1312            .map_err(|_| axum::headers::Error::invalid())?;
1313        Ok(Self(version))
1314    }
1315
1316    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317        values.extend([self.0.to_string().parse().unwrap()]);
1318    }
1319}
1320
1321pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1322    Router::new()
1323        .route("/rpc", get(handle_websocket_request))
1324        .layer(
1325            ServiceBuilder::new()
1326                .layer(Extension(server.app_state.clone()))
1327                .layer(middleware::from_fn(auth::validate_header)),
1328        )
1329        .route("/metrics", get(handle_metrics))
1330        .layer(Extension(server))
1331}
1332
1333pub async fn handle_websocket_request(
1334    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337    Extension(server): Extension<Arc<Server>>,
1338    Extension(principal): Extension<Principal>,
1339    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1340    ws: WebSocketUpgrade,
1341) -> axum::response::Response {
1342    if protocol_version != rpc::PROTOCOL_VERSION {
1343        return (
1344            StatusCode::UPGRADE_REQUIRED,
1345            "client must be upgraded".to_string(),
1346        )
1347            .into_response();
1348    }
1349
1350    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1351        return (
1352            StatusCode::UPGRADE_REQUIRED,
1353            "no version header found".to_string(),
1354        )
1355            .into_response();
1356    };
1357
1358    if !version.can_collaborate() {
1359        return (
1360            StatusCode::UPGRADE_REQUIRED,
1361            "client must be upgraded".to_string(),
1362        )
1363            .into_response();
1364    }
1365
1366    let socket_address = socket_address.to_string();
1367    ws.on_upgrade(move |socket| {
1368        let socket = socket
1369            .map_ok(to_tungstenite_message)
1370            .err_into()
1371            .with(|message| async move { to_axum_message(message) });
1372        let connection = Connection::new(Box::pin(socket));
1373        async move {
1374            server
1375                .handle_connection(
1376                    connection,
1377                    socket_address,
1378                    principal,
1379                    version,
1380                    country_code_header.map(|header| header.to_string()),
1381                    None,
1382                    Executor::Production,
1383                )
1384                .await;
1385        }
1386    })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391    let connections_metric = CONNECTIONS_METRIC
1392        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394    let connections = server
1395        .connection_pool
1396        .lock()
1397        .connections()
1398        .filter(|connection| !connection.admin)
1399        .count();
1400    connections_metric.set(connections as _);
1401
1402    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404        register_int_gauge!(
1405            "shared_projects",
1406            "number of open projects with one or more guests"
1407        )
1408        .unwrap()
1409    });
1410
1411    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412    shared_projects_metric.set(shared_projects as _);
1413
1414    let encoder = prometheus::TextEncoder::new();
1415    let metric_families = prometheus::gather();
1416    let encoded_metrics = encoder
1417        .encode_to_string(&metric_families)
1418        .map_err(|err| anyhow!("{}", err))?;
1419    Ok(encoded_metrics)
1420}
1421
1422#[instrument(err, skip(executor))]
1423async fn connection_lost(
1424    session: Session,
1425    mut teardown: watch::Receiver<bool>,
1426    executor: Executor,
1427) -> Result<()> {
1428    session.peer.disconnect(session.connection_id);
1429    session
1430        .connection_pool()
1431        .await
1432        .remove_connection(session.connection_id)?;
1433
1434    session
1435        .db()
1436        .await
1437        .connection_lost(session.connection_id)
1438        .await
1439        .trace_err();
1440
1441    futures::select_biased! {
1442        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1443            match &session.principal {
1444                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1445                    let session = session.for_user().unwrap();
1446
1447                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1448                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1449                    leave_channel_buffers_for_session(&session)
1450                        .await
1451                        .trace_err();
1452
1453                    if !session
1454                        .connection_pool()
1455                        .await
1456                        .is_user_online(session.user_id())
1457                    {
1458                        let db = session.db().await;
1459                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1460                            room_updated(&room, &session.peer);
1461                        }
1462                    }
1463
1464                    update_user_contacts(session.user_id(), &session).await?;
1465                },
1466            Principal::DevServer(_) => {
1467                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1468            },
1469        }
1470        },
1471        _ = teardown.changed().fuse() => {}
1472    }
1473
1474    Ok(())
1475}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479    response.send(proto::Ack {})?;
1480    Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485    _request: proto::CreateRoom,
1486    response: Response<proto::CreateRoom>,
1487    session: UserSession,
1488) -> Result<()> {
1489    let live_kit_room = nanoid::nanoid!(30);
1490
1491    let live_kit_connection_info = util::maybe!(async {
1492        let live_kit = session.app_state.live_kit_client.as_ref();
1493        let live_kit = live_kit?;
1494        let user_id = session.user_id().to_string();
1495
1496        let token = live_kit
1497            .room_token(&live_kit_room, &user_id.to_string())
1498            .trace_err()?;
1499
1500        Some(proto::LiveKitConnectionInfo {
1501            server_url: live_kit.url().into(),
1502            token,
1503            can_publish: true,
1504        })
1505    })
1506    .await;
1507
1508    let room = session
1509        .db()
1510        .await
1511        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512        .await?;
1513
1514    response.send(proto::CreateRoomResponse {
1515        room: Some(room.clone()),
1516        live_kit_connection_info,
1517    })?;
1518
1519    update_user_contacts(session.user_id(), &session).await?;
1520    Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525    request: proto::JoinRoom,
1526    response: Response<proto::JoinRoom>,
1527    session: UserSession,
1528) -> Result<()> {
1529    let room_id = RoomId::from_proto(request.id);
1530
1531    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533    if let Some(channel_id) = channel_id {
1534        return join_channel_internal(channel_id, Box::new(response), session).await;
1535    }
1536
1537    let joined_room = {
1538        let room = session
1539            .db()
1540            .await
1541            .join_room(room_id, session.user_id(), session.connection_id)
1542            .await?;
1543        room_updated(&room.room, &session.peer);
1544        room.into_inner()
1545    };
1546
1547    for connection_id in session
1548        .connection_pool()
1549        .await
1550        .user_connection_ids(session.user_id())
1551    {
1552        session
1553            .peer
1554            .send(
1555                connection_id,
1556                proto::CallCanceled {
1557                    room_id: room_id.to_proto(),
1558                },
1559            )
1560            .trace_err();
1561    }
1562
1563    let live_kit_connection_info =
1564        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1565            live_kit
1566                .room_token(
1567                    &joined_room.room.live_kit_room,
1568                    &session.user_id().to_string(),
1569                )
1570                .trace_err()
1571                .map(|token| proto::LiveKitConnectionInfo {
1572                    server_url: live_kit.url().into(),
1573                    token,
1574                    can_publish: true,
1575                })
1576        } else {
1577            None
1578        };
1579
1580    response.send(proto::JoinRoomResponse {
1581        room: Some(joined_room.room),
1582        channel_id: None,
1583        live_kit_connection_info,
1584    })?;
1585
1586    update_user_contacts(session.user_id(), &session).await?;
1587    Ok(())
1588}
1589
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The handler replies with the rejoined room and the lists of reshared and
/// rejoined projects, then notifies the other participants:
///
/// * `room_updated` is invoked for the rejoined room;
/// * collaborators of each reshared project are told that the project's old
///   connection id has been replaced by this session's connection id, and the
///   project's worktree metadata is broadcast to them;
/// * `notify_rejoined_projects` handles the projects this session rejoined.
///
/// Finally, if the room is associated with a channel, `channel_updated` is
/// invoked, and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared up front so they can be filled in from inside the block below,
    // which limits how long `rejoined_room` (as returned by the database) lives.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the requester before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator which peer id replaced the old one.
            // Send failures go through `trace_err` and do not abort the loop.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to its collaborators on
            // behalf of this session's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the inner value so the room and channel can outlive
        // this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1681
1682fn notify_rejoined_projects(
1683    rejoined_projects: &mut Vec<RejoinedProject>,
1684    session: &UserSession,
1685) -> Result<()> {
1686    for project in rejoined_projects.iter() {
1687        for collaborator in &project.collaborators {
1688            session
1689                .peer
1690                .send(
1691                    collaborator.connection_id,
1692                    proto::UpdateProjectCollaborator {
1693                        project_id: project.id.to_proto(),
1694                        old_peer_id: Some(project.old_connection_id.into()),
1695                        new_peer_id: Some(session.connection_id.into()),
1696                    },
1697                )
1698                .trace_err();
1699        }
1700    }
1701
1702    for project in rejoined_projects {
1703        for worktree in mem::take(&mut project.worktrees) {
1704            #[cfg(any(test, feature = "test-support"))]
1705            const MAX_CHUNK_SIZE: usize = 2;
1706            #[cfg(not(any(test, feature = "test-support")))]
1707            const MAX_CHUNK_SIZE: usize = 256;
1708
1709            // Stream this worktree's entries.
1710            let message = proto::UpdateWorktree {
1711                project_id: project.id.to_proto(),
1712                worktree_id: worktree.id,
1713                abs_path: worktree.abs_path.clone(),
1714                root_name: worktree.root_name,
1715                updated_entries: worktree.updated_entries,
1716                removed_entries: worktree.removed_entries,
1717                scan_id: worktree.scan_id,
1718                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1719                updated_repositories: worktree.updated_repositories,
1720                removed_repositories: worktree.removed_repositories,
1721            };
1722            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1723                session.peer.send(session.connection_id, update.clone())?;
1724            }
1725
1726            // Stream this worktree's diagnostics.
1727            for summary in worktree.diagnostic_summaries {
1728                session.peer.send(
1729                    session.connection_id,
1730                    proto::UpdateDiagnosticSummary {
1731                        project_id: project.id.to_proto(),
1732                        worktree_id: worktree.id,
1733                        summary: Some(summary),
1734                    },
1735                )?;
1736            }
1737
1738            for settings_file in worktree.settings_files {
1739                session.peer.send(
1740                    session.connection_id,
1741                    proto::UpdateWorktreeSettings {
1742                        project_id: project.id.to_proto(),
1743                        worktree_id: worktree.id,
1744                        path: settings_file.path,
1745                        content: Some(settings_file.content),
1746                    },
1747                )?;
1748            }
1749        }
1750
1751        for language_server in &project.language_servers {
1752            session.peer.send(
1753                session.connection_id,
1754                proto::UpdateLanguageServer {
1755                    project_id: project.id.to_proto(),
1756                    language_server_id: language_server.id,
1757                    variant: Some(
1758                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1759                            proto::LspDiskBasedDiagnosticsUpdated {},
1760                        ),
1761                    ),
1762                },
1763            )?;
1764        }
1765    }
1766    Ok(())
1767}
1768
1769/// leave room disconnects from the room.
1770async fn leave_room(
1771    _: proto::LeaveRoom,
1772    response: Response<proto::LeaveRoom>,
1773    session: UserSession,
1774) -> Result<()> {
1775    leave_room_for_session(&session, session.connection_id).await?;
1776    response.send(proto::Ack {})?;
1777    Ok(())
1778}
1779
1780/// Updates the permissions of someone else in the room.
1781async fn set_room_participant_role(
1782    request: proto::SetRoomParticipantRole,
1783    response: Response<proto::SetRoomParticipantRole>,
1784    session: UserSession,
1785) -> Result<()> {
1786    let user_id = UserId::from_proto(request.user_id);
1787    let role = ChannelRole::from(request.role());
1788
1789    let (live_kit_room, can_publish) = {
1790        let room = session
1791            .db()
1792            .await
1793            .set_room_participant_role(
1794                session.user_id(),
1795                RoomId::from_proto(request.room_id),
1796                user_id,
1797                role,
1798            )
1799            .await?;
1800
1801        let live_kit_room = room.live_kit_room.clone();
1802        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1803        room_updated(&room, &session.peer);
1804        (live_kit_room, can_publish)
1805    };
1806
1807    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1808        live_kit
1809            .update_participant(
1810                live_kit_room.clone(),
1811                request.user_id.to_string(),
1812                live_kit_server::proto::ParticipantPermission {
1813                    can_subscribe: true,
1814                    can_publish,
1815                    can_publish_data: can_publish,
1816                    hidden: false,
1817                    recorder: false,
1818                },
1819            )
1820            .await
1821            .trace_err();
1822    }
1823
1824    response.send(proto::Ack {})?;
1825    Ok(())
1826}
1827
1828/// Call someone else into the current room
1829async fn call(
1830    request: proto::Call,
1831    response: Response<proto::Call>,
1832    session: UserSession,
1833) -> Result<()> {
1834    let room_id = RoomId::from_proto(request.room_id);
1835    let calling_user_id = session.user_id();
1836    let calling_connection_id = session.connection_id;
1837    let called_user_id = UserId::from_proto(request.called_user_id);
1838    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1839    if !session
1840        .db()
1841        .await
1842        .has_contact(calling_user_id, called_user_id)
1843        .await?
1844    {
1845        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1846    }
1847
1848    let incoming_call = {
1849        let (room, incoming_call) = &mut *session
1850            .db()
1851            .await
1852            .call(
1853                room_id,
1854                calling_user_id,
1855                calling_connection_id,
1856                called_user_id,
1857                initial_project_id,
1858            )
1859            .await?;
1860        room_updated(room, &session.peer);
1861        mem::take(incoming_call)
1862    };
1863    update_user_contacts(called_user_id, &session).await?;
1864
1865    let mut calls = session
1866        .connection_pool()
1867        .await
1868        .user_connection_ids(called_user_id)
1869        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1870        .collect::<FuturesUnordered<_>>();
1871
1872    while let Some(call_response) = calls.next().await {
1873        match call_response.as_ref() {
1874            Ok(_) => {
1875                response.send(proto::Ack {})?;
1876                return Ok(());
1877            }
1878            Err(_) => {
1879                call_response.trace_err();
1880            }
1881        }
1882    }
1883
1884    {
1885        let room = session
1886            .db()
1887            .await
1888            .call_failed(room_id, called_user_id)
1889            .await?;
1890        room_updated(&room, &session.peer);
1891    }
1892    update_user_contacts(called_user_id, &session).await?;
1893
1894    Err(anyhow!("failed to ring user"))?
1895}
1896
1897/// Cancel an outgoing call.
1898async fn cancel_call(
1899    request: proto::CancelCall,
1900    response: Response<proto::CancelCall>,
1901    session: UserSession,
1902) -> Result<()> {
1903    let called_user_id = UserId::from_proto(request.called_user_id);
1904    let room_id = RoomId::from_proto(request.room_id);
1905    {
1906        let room = session
1907            .db()
1908            .await
1909            .cancel_call(room_id, session.connection_id, called_user_id)
1910            .await?;
1911        room_updated(&room, &session.peer);
1912    }
1913
1914    for connection_id in session
1915        .connection_pool()
1916        .await
1917        .user_connection_ids(called_user_id)
1918    {
1919        session
1920            .peer
1921            .send(
1922                connection_id,
1923                proto::CallCanceled {
1924                    room_id: room_id.to_proto(),
1925                },
1926            )
1927            .trace_err();
1928    }
1929    response.send(proto::Ack {})?;
1930
1931    update_user_contacts(called_user_id, &session).await?;
1932    Ok(())
1933}
1934
1935/// Decline an incoming call.
1936async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1937    let room_id = RoomId::from_proto(message.room_id);
1938    {
1939        let room = session
1940            .db()
1941            .await
1942            .decline_call(Some(room_id), session.user_id())
1943            .await?
1944            .ok_or_else(|| anyhow!("failed to decline call"))?;
1945        room_updated(&room, &session.peer);
1946    }
1947
1948    for connection_id in session
1949        .connection_pool()
1950        .await
1951        .user_connection_ids(session.user_id())
1952    {
1953        session
1954            .peer
1955            .send(
1956                connection_id,
1957                proto::CallCanceled {
1958                    room_id: room_id.to_proto(),
1959                },
1960            )
1961            .trace_err();
1962    }
1963    update_user_contacts(session.user_id(), &session).await?;
1964    Ok(())
1965}
1966
1967/// Updates other participants in the room with your current location.
1968async fn update_participant_location(
1969    request: proto::UpdateParticipantLocation,
1970    response: Response<proto::UpdateParticipantLocation>,
1971    session: UserSession,
1972) -> Result<()> {
1973    let room_id = RoomId::from_proto(request.room_id);
1974    let location = request
1975        .location
1976        .ok_or_else(|| anyhow!("invalid location"))?;
1977
1978    let db = session.db().await;
1979    let room = db
1980        .update_room_participant_location(room_id, session.connection_id, location)
1981        .await?;
1982
1983    room_updated(&room, &session.peer);
1984    response.send(proto::Ack {})?;
1985    Ok(())
1986}
1987
1988/// Share a project into the room.
1989async fn share_project(
1990    request: proto::ShareProject,
1991    response: Response<proto::ShareProject>,
1992    session: UserSession,
1993) -> Result<()> {
1994    let (project_id, room) = &*session
1995        .db()
1996        .await
1997        .share_project(
1998            RoomId::from_proto(request.room_id),
1999            session.connection_id,
2000            &request.worktrees,
2001            request.is_ssh_project,
2002            request
2003                .dev_server_project_id
2004                .map(DevServerProjectId::from_proto),
2005        )
2006        .await?;
2007    response.send(proto::ShareProjectResponse {
2008        project_id: project_id.to_proto(),
2009    })?;
2010    room_updated(room, &session.peer);
2011
2012    Ok(())
2013}
2014
2015/// Unshare a project from the room.
2016async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2017    let project_id = ProjectId::from_proto(message.project_id);
2018    unshare_project_internal(
2019        project_id,
2020        session.connection_id,
2021        session.user_id(),
2022        &session,
2023    )
2024    .await
2025}
2026
2027async fn unshare_project_internal(
2028    project_id: ProjectId,
2029    connection_id: ConnectionId,
2030    user_id: Option<UserId>,
2031    session: &Session,
2032) -> Result<()> {
2033    let delete = {
2034        let room_guard = session
2035            .db()
2036            .await
2037            .unshare_project(project_id, connection_id, user_id)
2038            .await?;
2039
2040        let (delete, room, guest_connection_ids) = &*room_guard;
2041
2042        let message = proto::UnshareProject {
2043            project_id: project_id.to_proto(),
2044        };
2045
2046        broadcast(
2047            Some(connection_id),
2048            guest_connection_ids.iter().copied(),
2049            |conn_id| session.peer.send(conn_id, message.clone()),
2050        );
2051        if let Some(room) = room {
2052            room_updated(room, &session.peer);
2053        }
2054
2055        *delete
2056    };
2057
2058    if delete {
2059        let db = session.db().await;
2060        db.delete_project(project_id).await?;
2061    }
2062
2063    Ok(())
2064}
2065
2066/// DevServer makes a project available online
2067async fn share_dev_server_project(
2068    request: proto::ShareDevServerProject,
2069    response: Response<proto::ShareDevServerProject>,
2070    session: DevServerSession,
2071) -> Result<()> {
2072    let (dev_server_project, user_id, status) = session
2073        .db()
2074        .await
2075        .share_dev_server_project(
2076            DevServerProjectId::from_proto(request.dev_server_project_id),
2077            session.dev_server_id(),
2078            session.connection_id,
2079            &request.worktrees,
2080        )
2081        .await?;
2082    let Some(project_id) = dev_server_project.project_id else {
2083        return Err(anyhow!("failed to share remote project"))?;
2084    };
2085
2086    send_dev_server_projects_update(user_id, status, &session).await;
2087
2088    response.send(proto::ShareProjectResponse { project_id })?;
2089
2090    Ok(())
2091}
2092
/// Join someone else's shared project.
///
/// Registers this connection as a participant of the project in the database
/// and then delegates to `join_project_internal`, which sends the response and
/// streams the project's state to the joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle explicitly before the remaining work;
    // `project` and `replica_id` stay usable after this point.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2111
/// A response handle that can be answered with a `proto::JoinProjectResponse`.
///
/// Implemented for the response handles of both `proto::JoinProject` and
/// `proto::JoinHostedProject`, so that `join_project_internal` can reply to
/// either kind of request.
trait JoinProjectInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to `Response`'s own `send`.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2125
2126fn join_project_internal(
2127    response: impl JoinProjectInternalResponse,
2128    session: UserSession,
2129    project: &mut Project,
2130    replica_id: &ReplicaId,
2131) -> Result<()> {
2132    let collaborators = project
2133        .collaborators
2134        .iter()
2135        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2136        .map(|collaborator| collaborator.to_proto())
2137        .collect::<Vec<_>>();
2138    let project_id = project.id;
2139    let guest_user_id = session.user_id();
2140
2141    let worktrees = project
2142        .worktrees
2143        .iter()
2144        .map(|(id, worktree)| proto::WorktreeMetadata {
2145            id: *id,
2146            root_name: worktree.root_name.clone(),
2147            visible: worktree.visible,
2148            abs_path: worktree.abs_path.clone(),
2149        })
2150        .collect::<Vec<_>>();
2151
2152    let add_project_collaborator = proto::AddProjectCollaborator {
2153        project_id: project_id.to_proto(),
2154        collaborator: Some(proto::Collaborator {
2155            peer_id: Some(session.connection_id.into()),
2156            replica_id: replica_id.0 as u32,
2157            user_id: guest_user_id.to_proto(),
2158        }),
2159    };
2160
2161    for collaborator in &collaborators {
2162        session
2163            .peer
2164            .send(
2165                collaborator.peer_id.unwrap().into(),
2166                add_project_collaborator.clone(),
2167            )
2168            .trace_err();
2169    }
2170
2171    // First, we send the metadata associated with each worktree.
2172    response.send(proto::JoinProjectResponse {
2173        project_id: project.id.0 as u64,
2174        worktrees: worktrees.clone(),
2175        replica_id: replica_id.0 as u32,
2176        collaborators: collaborators.clone(),
2177        language_servers: project.language_servers.clone(),
2178        role: project.role.into(),
2179        dev_server_project_id: project
2180            .dev_server_project_id
2181            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2182    })?;
2183
2184    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2185        #[cfg(any(test, feature = "test-support"))]
2186        const MAX_CHUNK_SIZE: usize = 2;
2187        #[cfg(not(any(test, feature = "test-support")))]
2188        const MAX_CHUNK_SIZE: usize = 256;
2189
2190        // Stream this worktree's entries.
2191        let message = proto::UpdateWorktree {
2192            project_id: project_id.to_proto(),
2193            worktree_id,
2194            abs_path: worktree.abs_path.clone(),
2195            root_name: worktree.root_name,
2196            updated_entries: worktree.entries,
2197            removed_entries: Default::default(),
2198            scan_id: worktree.scan_id,
2199            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2200            updated_repositories: worktree.repository_entries.into_values().collect(),
2201            removed_repositories: Default::default(),
2202        };
2203        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2204            session.peer.send(session.connection_id, update.clone())?;
2205        }
2206
2207        // Stream this worktree's diagnostics.
2208        for summary in worktree.diagnostic_summaries {
2209            session.peer.send(
2210                session.connection_id,
2211                proto::UpdateDiagnosticSummary {
2212                    project_id: project_id.to_proto(),
2213                    worktree_id: worktree.id,
2214                    summary: Some(summary),
2215                },
2216            )?;
2217        }
2218
2219        for settings_file in worktree.settings_files {
2220            session.peer.send(
2221                session.connection_id,
2222                proto::UpdateWorktreeSettings {
2223                    project_id: project_id.to_proto(),
2224                    worktree_id: worktree.id,
2225                    path: settings_file.path,
2226                    content: Some(settings_file.content),
2227                },
2228            )?;
2229        }
2230    }
2231
2232    for language_server in &project.language_servers {
2233        session.peer.send(
2234            session.connection_id,
2235            proto::UpdateLanguageServer {
2236                project_id: project_id.to_proto(),
2237                language_server_id: language_server.id,
2238                variant: Some(
2239                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2240                        proto::LspDiskBasedDiagnosticsUpdated {},
2241                    ),
2242                ),
2243            },
2244        )?;
2245    }
2246
2247    Ok(())
2248}
2249
2250/// Leave someone elses shared project.
2251async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2252    let sender_id = session.connection_id;
2253    let project_id = ProjectId::from_proto(request.project_id);
2254    let db = session.db().await;
2255    if db.is_hosted_project(project_id).await? {
2256        let project = db.leave_hosted_project(project_id, sender_id).await?;
2257        project_left(&project, &session);
2258        return Ok(());
2259    }
2260
2261    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2262    tracing::info!(
2263        %project_id,
2264        "leave project"
2265    );
2266
2267    project_left(project, &session);
2268    if let Some(room) = room {
2269        room_updated(room, &session.peer);
2270    }
2271
2272    Ok(())
2273}
2274
2275async fn join_hosted_project(
2276    request: proto::JoinHostedProject,
2277    response: Response<proto::JoinHostedProject>,
2278    session: UserSession,
2279) -> Result<()> {
2280    let (mut project, replica_id) = session
2281        .db()
2282        .await
2283        .join_hosted_project(
2284            ProjectId(request.project_id as i32),
2285            session.user_id(),
2286            session.connection_id,
2287        )
2288        .await?;
2289
2290    join_project_internal(response, session, &mut project, &replica_id)
2291}
2292
2293async fn list_remote_directory(
2294    request: proto::ListRemoteDirectory,
2295    response: Response<proto::ListRemoteDirectory>,
2296    session: UserSession,
2297) -> Result<()> {
2298    let dev_server_id = DevServerId(request.dev_server_id as i32);
2299    let dev_server_connection_id = session
2300        .connection_pool()
2301        .await
2302        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2303
2304    session
2305        .db()
2306        .await
2307        .get_dev_server_for_user(dev_server_id, session.user_id())
2308        .await?;
2309
2310    response.send(
2311        session
2312            .peer
2313            .forward_request(session.connection_id, dev_server_connection_id, request)
2314            .await?,
2315    )?;
2316    Ok(())
2317}
2318
2319async fn update_dev_server_project(
2320    request: proto::UpdateDevServerProject,
2321    response: Response<proto::UpdateDevServerProject>,
2322    session: UserSession,
2323) -> Result<()> {
2324    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2325
2326    let (dev_server_project, update) = session
2327        .db()
2328        .await
2329        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2330        .await?;
2331
2332    let projects = session
2333        .db()
2334        .await
2335        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2336        .await?;
2337
2338    let dev_server_connection_id = session
2339        .connection_pool()
2340        .await
2341        .dev_server_connection_id_supporting(
2342            dev_server_project.dev_server_id,
2343            ZedVersion::with_list_directory(),
2344        )?;
2345
2346    session.peer.send(
2347        dev_server_connection_id,
2348        proto::DevServerInstructions { projects },
2349    )?;
2350
2351    send_dev_server_projects_update(session.user_id(), update, &session).await;
2352
2353    response.send(proto::Ack {})
2354}
2355
2356async fn create_dev_server_project(
2357    request: proto::CreateDevServerProject,
2358    response: Response<proto::CreateDevServerProject>,
2359    session: UserSession,
2360) -> Result<()> {
2361    let dev_server_id = DevServerId(request.dev_server_id as i32);
2362    let dev_server_connection_id = session
2363        .connection_pool()
2364        .await
2365        .dev_server_connection_id(dev_server_id);
2366    let Some(dev_server_connection_id) = dev_server_connection_id else {
2367        Err(ErrorCode::DevServerOffline
2368            .message("Cannot create a remote project when the dev server is offline".to_string())
2369            .anyhow())?
2370    };
2371
2372    let path = request.path.clone();
2373    //Check that the path exists on the dev server
2374    session
2375        .peer
2376        .forward_request(
2377            session.connection_id,
2378            dev_server_connection_id,
2379            proto::ValidateDevServerProjectRequest { path: path.clone() },
2380        )
2381        .await?;
2382
2383    let (dev_server_project, update) = session
2384        .db()
2385        .await
2386        .create_dev_server_project(
2387            DevServerId(request.dev_server_id as i32),
2388            &request.path,
2389            session.user_id(),
2390        )
2391        .await?;
2392
2393    let projects = session
2394        .db()
2395        .await
2396        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2397        .await?;
2398
2399    session.peer.send(
2400        dev_server_connection_id,
2401        proto::DevServerInstructions { projects },
2402    )?;
2403
2404    send_dev_server_projects_update(session.user_id(), update, &session).await;
2405
2406    response.send(proto::CreateDevServerProjectResponse {
2407        dev_server_project: Some(dev_server_project.to_proto(None)),
2408    })?;
2409    Ok(())
2410}
2411
2412async fn create_dev_server(
2413    request: proto::CreateDevServer,
2414    response: Response<proto::CreateDevServer>,
2415    session: UserSession,
2416) -> Result<()> {
2417    let access_token = auth::random_token();
2418    let hashed_access_token = auth::hash_access_token(&access_token);
2419
2420    if request.name.is_empty() {
2421        return Err(proto::ErrorCode::Forbidden
2422            .message("Dev server name cannot be empty".to_string())
2423            .anyhow())?;
2424    }
2425
2426    let (dev_server, status) = session
2427        .db()
2428        .await
2429        .create_dev_server(
2430            &request.name,
2431            request.ssh_connection_string.as_deref(),
2432            &hashed_access_token,
2433            session.user_id(),
2434        )
2435        .await?;
2436
2437    send_dev_server_projects_update(session.user_id(), status, &session).await;
2438
2439    response.send(proto::CreateDevServerResponse {
2440        dev_server_id: dev_server.id.0 as u64,
2441        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2442        name: request.name,
2443    })?;
2444    Ok(())
2445}
2446
2447async fn regenerate_dev_server_token(
2448    request: proto::RegenerateDevServerToken,
2449    response: Response<proto::RegenerateDevServerToken>,
2450    session: UserSession,
2451) -> Result<()> {
2452    let dev_server_id = DevServerId(request.dev_server_id as i32);
2453    let access_token = auth::random_token();
2454    let hashed_access_token = auth::hash_access_token(&access_token);
2455
2456    let connection_id = session
2457        .connection_pool()
2458        .await
2459        .dev_server_connection_id(dev_server_id);
2460    if let Some(connection_id) = connection_id {
2461        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2462        session.peer.send(
2463            connection_id,
2464            proto::ShutdownDevServer {
2465                reason: Some("dev server token was regenerated".to_string()),
2466            },
2467        )?;
2468        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2469    }
2470
2471    let status = session
2472        .db()
2473        .await
2474        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2475        .await?;
2476
2477    send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479    response.send(proto::RegenerateDevServerTokenResponse {
2480        dev_server_id: dev_server_id.to_proto(),
2481        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2482    })?;
2483    Ok(())
2484}
2485
2486async fn rename_dev_server(
2487    request: proto::RenameDevServer,
2488    response: Response<proto::RenameDevServer>,
2489    session: UserSession,
2490) -> Result<()> {
2491    if request.name.trim().is_empty() {
2492        return Err(proto::ErrorCode::Forbidden
2493            .message("Dev server name cannot be empty".to_string())
2494            .anyhow())?;
2495    }
2496
2497    let dev_server_id = DevServerId(request.dev_server_id as i32);
2498    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2499    if dev_server.user_id != session.user_id() {
2500        return Err(anyhow!(ErrorCode::Forbidden))?;
2501    }
2502
2503    let status = session
2504        .db()
2505        .await
2506        .rename_dev_server(
2507            dev_server_id,
2508            &request.name,
2509            request.ssh_connection_string.as_deref(),
2510            session.user_id(),
2511        )
2512        .await?;
2513
2514    send_dev_server_projects_update(session.user_id(), status, &session).await;
2515
2516    response.send(proto::Ack {})?;
2517    Ok(())
2518}
2519
2520async fn delete_dev_server(
2521    request: proto::DeleteDevServer,
2522    response: Response<proto::DeleteDevServer>,
2523    session: UserSession,
2524) -> Result<()> {
2525    let dev_server_id = DevServerId(request.dev_server_id as i32);
2526    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2527    if dev_server.user_id != session.user_id() {
2528        return Err(anyhow!(ErrorCode::Forbidden))?;
2529    }
2530
2531    let connection_id = session
2532        .connection_pool()
2533        .await
2534        .dev_server_connection_id(dev_server_id);
2535    if let Some(connection_id) = connection_id {
2536        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2537        session.peer.send(
2538            connection_id,
2539            proto::ShutdownDevServer {
2540                reason: Some("dev server was deleted".to_string()),
2541            },
2542        )?;
2543        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2544    }
2545
2546    let status = session
2547        .db()
2548        .await
2549        .delete_dev_server(dev_server_id, session.user_id())
2550        .await?;
2551
2552    send_dev_server_projects_update(session.user_id(), status, &session).await;
2553
2554    response.send(proto::Ack {})?;
2555    Ok(())
2556}
2557
2558async fn delete_dev_server_project(
2559    request: proto::DeleteDevServerProject,
2560    response: Response<proto::DeleteDevServerProject>,
2561    session: UserSession,
2562) -> Result<()> {
2563    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2564    let dev_server_project = session
2565        .db()
2566        .await
2567        .get_dev_server_project(dev_server_project_id)
2568        .await?;
2569
2570    let dev_server = session
2571        .db()
2572        .await
2573        .get_dev_server(dev_server_project.dev_server_id)
2574        .await?;
2575    if dev_server.user_id != session.user_id() {
2576        return Err(anyhow!(ErrorCode::Forbidden))?;
2577    }
2578
2579    let dev_server_connection_id = session
2580        .connection_pool()
2581        .await
2582        .dev_server_connection_id(dev_server.id);
2583
2584    if let Some(dev_server_connection_id) = dev_server_connection_id {
2585        let project = session
2586            .db()
2587            .await
2588            .find_dev_server_project(dev_server_project_id)
2589            .await;
2590        if let Ok(project) = project {
2591            unshare_project_internal(
2592                project.id,
2593                dev_server_connection_id,
2594                Some(session.user_id()),
2595                &session,
2596            )
2597            .await?;
2598        }
2599    }
2600
2601    let (projects, status) = session
2602        .db()
2603        .await
2604        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2605        .await?;
2606
2607    if let Some(dev_server_connection_id) = dev_server_connection_id {
2608        session.peer.send(
2609            dev_server_connection_id,
2610            proto::DevServerInstructions { projects },
2611        )?;
2612    }
2613
2614    send_dev_server_projects_update(session.user_id(), status, &session).await;
2615
2616    response.send(proto::Ack {})?;
2617    Ok(())
2618}
2619
2620async fn rejoin_dev_server_projects(
2621    request: proto::RejoinRemoteProjects,
2622    response: Response<proto::RejoinRemoteProjects>,
2623    session: UserSession,
2624) -> Result<()> {
2625    let mut rejoined_projects = {
2626        let db = session.db().await;
2627        db.rejoin_dev_server_projects(
2628            &request.rejoined_projects,
2629            session.user_id(),
2630            session.0.connection_id,
2631        )
2632        .await?
2633    };
2634    response.send(proto::RejoinRemoteProjectsResponse {
2635        rejoined_projects: rejoined_projects
2636            .iter()
2637            .map(|project| project.to_proto())
2638            .collect(),
2639    })?;
2640    notify_rejoined_projects(&mut rejoined_projects, &session)
2641}
2642
2643async fn reconnect_dev_server(
2644    request: proto::ReconnectDevServer,
2645    response: Response<proto::ReconnectDevServer>,
2646    session: DevServerSession,
2647) -> Result<()> {
2648    let reshared_projects = {
2649        let db = session.db().await;
2650        db.reshare_dev_server_projects(
2651            &request.reshared_projects,
2652            session.dev_server_id(),
2653            session.0.connection_id,
2654        )
2655        .await?
2656    };
2657
2658    for project in &reshared_projects {
2659        for collaborator in &project.collaborators {
2660            session
2661                .peer
2662                .send(
2663                    collaborator.connection_id,
2664                    proto::UpdateProjectCollaborator {
2665                        project_id: project.id.to_proto(),
2666                        old_peer_id: Some(project.old_connection_id.into()),
2667                        new_peer_id: Some(session.connection_id.into()),
2668                    },
2669                )
2670                .trace_err();
2671        }
2672
2673        broadcast(
2674            Some(session.connection_id),
2675            project
2676                .collaborators
2677                .iter()
2678                .map(|collaborator| collaborator.connection_id),
2679            |connection_id| {
2680                session.peer.forward_send(
2681                    session.connection_id,
2682                    connection_id,
2683                    proto::UpdateProject {
2684                        project_id: project.id.to_proto(),
2685                        worktrees: project.worktrees.clone(),
2686                    },
2687                )
2688            },
2689        );
2690    }
2691
2692    response.send(proto::ReconnectDevServerResponse {
2693        reshared_projects: reshared_projects
2694            .iter()
2695            .map(|project| proto::ResharedProject {
2696                id: project.id.to_proto(),
2697                collaborators: project
2698                    .collaborators
2699                    .iter()
2700                    .map(|collaborator| collaborator.to_proto())
2701                    .collect(),
2702            })
2703            .collect(),
2704    })?;
2705
2706    Ok(())
2707}
2708
2709async fn shutdown_dev_server(
2710    _: proto::ShutdownDevServer,
2711    response: Response<proto::ShutdownDevServer>,
2712    session: DevServerSession,
2713) -> Result<()> {
2714    response.send(proto::Ack {})?;
2715    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2716    remove_dev_server_connection(session.dev_server_id(), &session).await
2717}
2718
2719async fn shutdown_dev_server_internal(
2720    dev_server_id: DevServerId,
2721    connection_id: ConnectionId,
2722    session: &Session,
2723) -> Result<()> {
2724    let (dev_server_projects, dev_server) = {
2725        let db = session.db().await;
2726        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2727        let dev_server = db.get_dev_server(dev_server_id).await?;
2728        (dev_server_projects, dev_server)
2729    };
2730
2731    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2732        unshare_project_internal(
2733            ProjectId::from_proto(project_id),
2734            connection_id,
2735            None,
2736            session,
2737        )
2738        .await?;
2739    }
2740
2741    session
2742        .connection_pool()
2743        .await
2744        .set_dev_server_offline(dev_server_id);
2745
2746    let status = session
2747        .db()
2748        .await
2749        .dev_server_projects_update(dev_server.user_id)
2750        .await?;
2751    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2752
2753    Ok(())
2754}
2755
2756async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2757    let dev_server_connection = session
2758        .connection_pool()
2759        .await
2760        .dev_server_connection_id(dev_server_id);
2761
2762    if let Some(dev_server_connection) = dev_server_connection {
2763        session
2764            .connection_pool()
2765            .await
2766            .remove_connection(dev_server_connection)?;
2767    }
2768    Ok(())
2769}
2770
2771/// Updates other participants with changes to the project
2772async fn update_project(
2773    request: proto::UpdateProject,
2774    response: Response<proto::UpdateProject>,
2775    session: Session,
2776) -> Result<()> {
2777    let project_id = ProjectId::from_proto(request.project_id);
2778    let (room, guest_connection_ids) = &*session
2779        .db()
2780        .await
2781        .update_project(project_id, session.connection_id, &request.worktrees)
2782        .await?;
2783    broadcast(
2784        Some(session.connection_id),
2785        guest_connection_ids.iter().copied(),
2786        |connection_id| {
2787            session
2788                .peer
2789                .forward_send(session.connection_id, connection_id, request.clone())
2790        },
2791    );
2792    if let Some(room) = room {
2793        room_updated(room, &session.peer);
2794    }
2795    response.send(proto::Ack {})?;
2796
2797    Ok(())
2798}
2799
2800/// Updates other participants with changes to the worktree
2801async fn update_worktree(
2802    request: proto::UpdateWorktree,
2803    response: Response<proto::UpdateWorktree>,
2804    session: Session,
2805) -> Result<()> {
2806    let guest_connection_ids = session
2807        .db()
2808        .await
2809        .update_worktree(&request, session.connection_id)
2810        .await?;
2811
2812    broadcast(
2813        Some(session.connection_id),
2814        guest_connection_ids.iter().copied(),
2815        |connection_id| {
2816            session
2817                .peer
2818                .forward_send(session.connection_id, connection_id, request.clone())
2819        },
2820    );
2821    response.send(proto::Ack {})?;
2822    Ok(())
2823}
2824
2825/// Updates other participants with changes to the diagnostics
2826async fn update_diagnostic_summary(
2827    message: proto::UpdateDiagnosticSummary,
2828    session: Session,
2829) -> Result<()> {
2830    let guest_connection_ids = session
2831        .db()
2832        .await
2833        .update_diagnostic_summary(&message, session.connection_id)
2834        .await?;
2835
2836    broadcast(
2837        Some(session.connection_id),
2838        guest_connection_ids.iter().copied(),
2839        |connection_id| {
2840            session
2841                .peer
2842                .forward_send(session.connection_id, connection_id, message.clone())
2843        },
2844    );
2845
2846    Ok(())
2847}
2848
2849/// Updates other participants with changes to the worktree settings
2850async fn update_worktree_settings(
2851    message: proto::UpdateWorktreeSettings,
2852    session: Session,
2853) -> Result<()> {
2854    let guest_connection_ids = session
2855        .db()
2856        .await
2857        .update_worktree_settings(&message, session.connection_id)
2858        .await?;
2859
2860    broadcast(
2861        Some(session.connection_id),
2862        guest_connection_ids.iter().copied(),
2863        |connection_id| {
2864            session
2865                .peer
2866                .forward_send(session.connection_id, connection_id, message.clone())
2867        },
2868    );
2869
2870    Ok(())
2871}
2872
2873/// Notify other participants that a  language server has started.
2874async fn start_language_server(
2875    request: proto::StartLanguageServer,
2876    session: Session,
2877) -> Result<()> {
2878    let guest_connection_ids = session
2879        .db()
2880        .await
2881        .start_language_server(&request, session.connection_id)
2882        .await?;
2883
2884    broadcast(
2885        Some(session.connection_id),
2886        guest_connection_ids.iter().copied(),
2887        |connection_id| {
2888            session
2889                .peer
2890                .forward_send(session.connection_id, connection_id, request.clone())
2891        },
2892    );
2893    Ok(())
2894}
2895
2896/// Notify other participants that a language server has changed.
2897async fn update_language_server(
2898    request: proto::UpdateLanguageServer,
2899    session: Session,
2900) -> Result<()> {
2901    let project_id = ProjectId::from_proto(request.project_id);
2902    let project_connection_ids = session
2903        .db()
2904        .await
2905        .project_connection_ids(project_id, session.connection_id, true)
2906        .await?;
2907    broadcast(
2908        Some(session.connection_id),
2909        project_connection_ids.iter().copied(),
2910        |connection_id| {
2911            session
2912                .peer
2913                .forward_send(session.connection_id, connection_id, request.clone())
2914        },
2915    );
2916    Ok(())
2917}
2918
2919/// forward a project request to the host. These requests should be read only
2920/// as guests are allowed to send them.
2921async fn forward_read_only_project_request<T>(
2922    request: T,
2923    response: Response<T>,
2924    session: UserSession,
2925) -> Result<()>
2926where
2927    T: EntityMessage + RequestMessage,
2928{
2929    let project_id = ProjectId::from_proto(request.remote_entity_id());
2930    let host_connection_id = session
2931        .db()
2932        .await
2933        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2934        .await?;
2935    let payload = session
2936        .peer
2937        .forward_request(session.connection_id, host_connection_id, request)
2938        .await?;
2939    response.send(payload)?;
2940    Ok(())
2941}
2942
2943async fn forward_find_search_candidates_request(
2944    request: proto::FindSearchCandidates,
2945    response: Response<proto::FindSearchCandidates>,
2946    session: UserSession,
2947) -> Result<()> {
2948    let project_id = ProjectId::from_proto(request.remote_entity_id());
2949    let host_connection_id = session
2950        .db()
2951        .await
2952        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2953        .await?;
2954
2955    let host_version = session
2956        .connection_pool()
2957        .await
2958        .connection(host_connection_id)
2959        .map(|c| c.zed_version);
2960
2961    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2962    {
2963        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2964        let search = proto::SearchProject {
2965            project_id: project_id.to_proto(),
2966            query: query.query,
2967            regex: query.regex,
2968            whole_word: query.whole_word,
2969            case_sensitive: query.case_sensitive,
2970            files_to_include: query.files_to_include,
2971            files_to_exclude: query.files_to_exclude,
2972            include_ignored: query.include_ignored,
2973        };
2974
2975        let payload = session
2976            .peer
2977            .forward_request(session.connection_id, host_connection_id, search)
2978            .await?;
2979        return response.send(proto::FindSearchCandidatesResponse {
2980            buffer_ids: payload
2981                .locations
2982                .into_iter()
2983                .map(|loc| loc.buffer_id)
2984                .collect(),
2985        });
2986    }
2987
2988    let payload = session
2989        .peer
2990        .forward_request(session.connection_id, host_connection_id, request)
2991        .await?;
2992    response.send(payload)?;
2993    Ok(())
2994}
2995
2996/// forward a project request to the dev server. Only allowed
2997/// if it's your dev server.
2998async fn forward_project_request_for_owner<T>(
2999    request: T,
3000    response: Response<T>,
3001    session: UserSession,
3002) -> Result<()>
3003where
3004    T: EntityMessage + RequestMessage,
3005{
3006    let project_id = ProjectId::from_proto(request.remote_entity_id());
3007
3008    let host_connection_id = session
3009        .db()
3010        .await
3011        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3012        .await?;
3013    let payload = session
3014        .peer
3015        .forward_request(session.connection_id, host_connection_id, request)
3016        .await?;
3017    response.send(payload)?;
3018    Ok(())
3019}
3020
3021/// forward a project request to the host. These requests are disallowed
3022/// for guests.
3023async fn forward_mutating_project_request<T>(
3024    request: T,
3025    response: Response<T>,
3026    session: UserSession,
3027) -> Result<()>
3028where
3029    T: EntityMessage + RequestMessage,
3030{
3031    let project_id = ProjectId::from_proto(request.remote_entity_id());
3032
3033    let host_connection_id = session
3034        .db()
3035        .await
3036        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3037        .await?;
3038    let payload = session
3039        .peer
3040        .forward_request(session.connection_id, host_connection_id, request)
3041        .await?;
3042    response.send(payload)?;
3043    Ok(())
3044}
3045
3046/// Notify other participants that a new buffer has been created
3047async fn create_buffer_for_peer(
3048    request: proto::CreateBufferForPeer,
3049    session: Session,
3050) -> Result<()> {
3051    session
3052        .db()
3053        .await
3054        .check_user_is_project_host(
3055            ProjectId::from_proto(request.project_id),
3056            session.connection_id,
3057        )
3058        .await?;
3059    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3060    session
3061        .peer
3062        .forward_send(session.connection_id, peer_id.into(), request)?;
3063    Ok(())
3064}
3065
3066/// Notify other participants that a buffer has been updated. This is
3067/// allowed for guests as long as the update is limited to selections.
3068async fn update_buffer(
3069    request: proto::UpdateBuffer,
3070    response: Response<proto::UpdateBuffer>,
3071    session: Session,
3072) -> Result<()> {
3073    let project_id = ProjectId::from_proto(request.project_id);
3074    let mut capability = Capability::ReadOnly;
3075
3076    for op in request.operations.iter() {
3077        match op.variant {
3078            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3079            Some(_) => capability = Capability::ReadWrite,
3080        }
3081    }
3082
3083    let host = {
3084        let guard = session
3085            .db()
3086            .await
3087            .connections_for_buffer_update(
3088                project_id,
3089                session.principal_id(),
3090                session.connection_id,
3091                capability,
3092            )
3093            .await?;
3094
3095        let (host, guests) = &*guard;
3096
3097        broadcast(
3098            Some(session.connection_id),
3099            guests.clone(),
3100            |connection_id| {
3101                session
3102                    .peer
3103                    .forward_send(session.connection_id, connection_id, request.clone())
3104            },
3105        );
3106
3107        *host
3108    };
3109
3110    if host != session.connection_id {
3111        session
3112            .peer
3113            .forward_request(session.connection_id, host, request.clone())
3114            .await?;
3115    }
3116
3117    response.send(proto::Ack {})?;
3118    Ok(())
3119}
3120
3121async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3122    let project_id = ProjectId::from_proto(message.project_id);
3123
3124    let operation = message.operation.as_ref().context("invalid operation")?;
3125    let capability = match operation.variant.as_ref() {
3126        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3127            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3128                match buffer_op.variant {
3129                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3130                        Capability::ReadOnly
3131                    }
3132                    _ => Capability::ReadWrite,
3133                }
3134            } else {
3135                Capability::ReadWrite
3136            }
3137        }
3138        Some(_) => Capability::ReadWrite,
3139        None => Capability::ReadOnly,
3140    };
3141
3142    let guard = session
3143        .db()
3144        .await
3145        .connections_for_buffer_update(
3146            project_id,
3147            session.principal_id(),
3148            session.connection_id,
3149            capability,
3150        )
3151        .await?;
3152
3153    let (host, guests) = &*guard;
3154
3155    broadcast(
3156        Some(session.connection_id),
3157        guests.iter().chain([host]).copied(),
3158        |connection_id| {
3159            session
3160                .peer
3161                .forward_send(session.connection_id, connection_id, message.clone())
3162        },
3163    );
3164
3165    Ok(())
3166}
3167
3168/// Notify other participants that a project has been updated.
3169async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3170    request: T,
3171    session: Session,
3172) -> Result<()> {
3173    let project_id = ProjectId::from_proto(request.remote_entity_id());
3174    let project_connection_ids = session
3175        .db()
3176        .await
3177        .project_connection_ids(project_id, session.connection_id, false)
3178        .await?;
3179
3180    broadcast(
3181        Some(session.connection_id),
3182        project_connection_ids.iter().copied(),
3183        |connection_id| {
3184            session
3185                .peer
3186                .forward_send(session.connection_id, connection_id, request.clone())
3187        },
3188    );
3189    Ok(())
3190}
3191
3192/// Start following another user in a call.
3193async fn follow(
3194    request: proto::Follow,
3195    response: Response<proto::Follow>,
3196    session: UserSession,
3197) -> Result<()> {
3198    let room_id = RoomId::from_proto(request.room_id);
3199    let project_id = request.project_id.map(ProjectId::from_proto);
3200    let leader_id = request
3201        .leader_id
3202        .ok_or_else(|| anyhow!("invalid leader id"))?
3203        .into();
3204    let follower_id = session.connection_id;
3205
3206    session
3207        .db()
3208        .await
3209        .check_room_participants(room_id, leader_id, session.connection_id)
3210        .await?;
3211
3212    let response_payload = session
3213        .peer
3214        .forward_request(session.connection_id, leader_id, request)
3215        .await?;
3216    response.send(response_payload)?;
3217
3218    if let Some(project_id) = project_id {
3219        let room = session
3220            .db()
3221            .await
3222            .follow(room_id, project_id, leader_id, follower_id)
3223            .await?;
3224        room_updated(&room, &session.peer);
3225    }
3226
3227    Ok(())
3228}
3229
3230/// Stop following another user in a call.
3231async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3232    let room_id = RoomId::from_proto(request.room_id);
3233    let project_id = request.project_id.map(ProjectId::from_proto);
3234    let leader_id = request
3235        .leader_id
3236        .ok_or_else(|| anyhow!("invalid leader id"))?
3237        .into();
3238    let follower_id = session.connection_id;
3239
3240    session
3241        .db()
3242        .await
3243        .check_room_participants(room_id, leader_id, session.connection_id)
3244        .await?;
3245
3246    session
3247        .peer
3248        .forward_send(session.connection_id, leader_id, request)?;
3249
3250    if let Some(project_id) = project_id {
3251        let room = session
3252            .db()
3253            .await
3254            .unfollow(room_id, project_id, leader_id, follower_id)
3255            .await?;
3256        room_updated(&room, &session.peer);
3257    }
3258
3259    Ok(())
3260}
3261
3262/// Notify everyone following you of your current location.
3263async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3264    let room_id = RoomId::from_proto(request.room_id);
3265    let database = session.db.lock().await;
3266
3267    let connection_ids = if let Some(project_id) = request.project_id {
3268        let project_id = ProjectId::from_proto(project_id);
3269        database
3270            .project_connection_ids(project_id, session.connection_id, true)
3271            .await?
3272    } else {
3273        database
3274            .room_connection_ids(room_id, session.connection_id)
3275            .await?
3276    };
3277
3278    // For now, don't send view update messages back to that view's current leader.
3279    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3280        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3281        _ => None,
3282    });
3283
3284    for connection_id in connection_ids.iter().cloned() {
3285        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3286            session
3287                .peer
3288                .forward_send(session.connection_id, connection_id, request.clone())?;
3289        }
3290    }
3291    Ok(())
3292}
3293
3294/// Get public data about users.
3295async fn get_users(
3296    request: proto::GetUsers,
3297    response: Response<proto::GetUsers>,
3298    session: Session,
3299) -> Result<()> {
3300    let user_ids = request
3301        .user_ids
3302        .into_iter()
3303        .map(UserId::from_proto)
3304        .collect();
3305    let users = session
3306        .db()
3307        .await
3308        .get_users_by_ids(user_ids)
3309        .await?
3310        .into_iter()
3311        .map(|user| proto::User {
3312            id: user.id.to_proto(),
3313            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3314            github_login: user.github_login,
3315        })
3316        .collect();
3317    response.send(proto::UsersResponse { users })?;
3318    Ok(())
3319}
3320
3321/// Search for users (to invite) buy Github login
3322async fn fuzzy_search_users(
3323    request: proto::FuzzySearchUsers,
3324    response: Response<proto::FuzzySearchUsers>,
3325    session: UserSession,
3326) -> Result<()> {
3327    let query = request.query;
3328    let users = match query.len() {
3329        0 => vec![],
3330        1 | 2 => session
3331            .db()
3332            .await
3333            .get_user_by_github_login(&query)
3334            .await?
3335            .into_iter()
3336            .collect(),
3337        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3338    };
3339    let users = users
3340        .into_iter()
3341        .filter(|user| user.id != session.user_id())
3342        .map(|user| proto::User {
3343            id: user.id.to_proto(),
3344            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3345            github_login: user.github_login,
3346        })
3347        .collect();
3348    response.send(proto::UsersResponse { users })?;
3349    Ok(())
3350}
3351
3352/// Send a contact request to another user.
3353async fn request_contact(
3354    request: proto::RequestContact,
3355    response: Response<proto::RequestContact>,
3356    session: UserSession,
3357) -> Result<()> {
3358    let requester_id = session.user_id();
3359    let responder_id = UserId::from_proto(request.responder_id);
3360    if requester_id == responder_id {
3361        return Err(anyhow!("cannot add yourself as a contact"))?;
3362    }
3363
3364    let notifications = session
3365        .db()
3366        .await
3367        .send_contact_request(requester_id, responder_id)
3368        .await?;
3369
3370    // Update outgoing contact requests of requester
3371    let mut update = proto::UpdateContacts::default();
3372    update.outgoing_requests.push(responder_id.to_proto());
3373    for connection_id in session
3374        .connection_pool()
3375        .await
3376        .user_connection_ids(requester_id)
3377    {
3378        session.peer.send(connection_id, update.clone())?;
3379    }
3380
3381    // Update incoming contact requests of responder
3382    let mut update = proto::UpdateContacts::default();
3383    update
3384        .incoming_requests
3385        .push(proto::IncomingContactRequest {
3386            requester_id: requester_id.to_proto(),
3387        });
3388    let connection_pool = session.connection_pool().await;
3389    for connection_id in connection_pool.user_connection_ids(responder_id) {
3390        session.peer.send(connection_id, update.clone())?;
3391    }
3392
3393    send_notifications(&connection_pool, &session.peer, notifications);
3394
3395    response.send(proto::Ack {})?;
3396    Ok(())
3397}
3398
3399/// Accept or decline a contact request
3400async fn respond_to_contact_request(
3401    request: proto::RespondToContactRequest,
3402    response: Response<proto::RespondToContactRequest>,
3403    session: UserSession,
3404) -> Result<()> {
3405    let responder_id = session.user_id();
3406    let requester_id = UserId::from_proto(request.requester_id);
3407    let db = session.db().await;
3408    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3409        db.dismiss_contact_notification(responder_id, requester_id)
3410            .await?;
3411    } else {
3412        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3413
3414        let notifications = db
3415            .respond_to_contact_request(responder_id, requester_id, accept)
3416            .await?;
3417        let requester_busy = db.is_user_busy(requester_id).await?;
3418        let responder_busy = db.is_user_busy(responder_id).await?;
3419
3420        let pool = session.connection_pool().await;
3421        // Update responder with new contact
3422        let mut update = proto::UpdateContacts::default();
3423        if accept {
3424            update
3425                .contacts
3426                .push(contact_for_user(requester_id, requester_busy, &pool));
3427        }
3428        update
3429            .remove_incoming_requests
3430            .push(requester_id.to_proto());
3431        for connection_id in pool.user_connection_ids(responder_id) {
3432            session.peer.send(connection_id, update.clone())?;
3433        }
3434
3435        // Update requester with new contact
3436        let mut update = proto::UpdateContacts::default();
3437        if accept {
3438            update
3439                .contacts
3440                .push(contact_for_user(responder_id, responder_busy, &pool));
3441        }
3442        update
3443            .remove_outgoing_requests
3444            .push(responder_id.to_proto());
3445
3446        for connection_id in pool.user_connection_ids(requester_id) {
3447            session.peer.send(connection_id, update.clone())?;
3448        }
3449
3450        send_notifications(&pool, &session.peer, notifications);
3451    }
3452
3453    response.send(proto::Ack {})?;
3454    Ok(())
3455}
3456
3457/// Remove a contact.
3458async fn remove_contact(
3459    request: proto::RemoveContact,
3460    response: Response<proto::RemoveContact>,
3461    session: UserSession,
3462) -> Result<()> {
3463    let requester_id = session.user_id();
3464    let responder_id = UserId::from_proto(request.user_id);
3465    let db = session.db().await;
3466    let (contact_accepted, deleted_notification_id) =
3467        db.remove_contact(requester_id, responder_id).await?;
3468
3469    let pool = session.connection_pool().await;
3470    // Update outgoing contact requests of requester
3471    let mut update = proto::UpdateContacts::default();
3472    if contact_accepted {
3473        update.remove_contacts.push(responder_id.to_proto());
3474    } else {
3475        update
3476            .remove_outgoing_requests
3477            .push(responder_id.to_proto());
3478    }
3479    for connection_id in pool.user_connection_ids(requester_id) {
3480        session.peer.send(connection_id, update.clone())?;
3481    }
3482
3483    // Update incoming contact requests of responder
3484    let mut update = proto::UpdateContacts::default();
3485    if contact_accepted {
3486        update.remove_contacts.push(requester_id.to_proto());
3487    } else {
3488        update
3489            .remove_incoming_requests
3490            .push(requester_id.to_proto());
3491    }
3492    for connection_id in pool.user_connection_ids(responder_id) {
3493        session.peer.send(connection_id, update.clone())?;
3494        if let Some(notification_id) = deleted_notification_id {
3495            session.peer.send(
3496                connection_id,
3497                proto::DeleteNotification {
3498                    notification_id: notification_id.to_proto(),
3499                },
3500            )?;
3501        }
3502    }
3503
3504    response.send(proto::Ack {})?;
3505    Ok(())
3506}
3507
3508fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3509    version.0.minor() < 139
3510}
3511
3512async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3513    let plan = session.current_plan(session.db().await).await?;
3514
3515    session
3516        .peer
3517        .send(
3518            session.connection_id,
3519            proto::UpdateUserPlan { plan: plan.into() },
3520        )
3521        .trace_err();
3522
3523    Ok(())
3524}
3525
3526async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3527    subscribe_user_to_channels(
3528        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3529        &session,
3530    )
3531    .await?;
3532    Ok(())
3533}
3534
3535async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3536    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3537    let mut pool = session.connection_pool().await;
3538    for membership in &channels_for_user.channel_memberships {
3539        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3540    }
3541    session.peer.send(
3542        session.connection_id,
3543        build_update_user_channels(&channels_for_user),
3544    )?;
3545    session.peer.send(
3546        session.connection_id,
3547        build_channels_update(channels_for_user),
3548    )?;
3549    Ok(())
3550}
3551
3552/// Creates a new channel.
3553async fn create_channel(
3554    request: proto::CreateChannel,
3555    response: Response<proto::CreateChannel>,
3556    session: UserSession,
3557) -> Result<()> {
3558    let db = session.db().await;
3559
3560    let parent_id = request.parent_id.map(ChannelId::from_proto);
3561    let (channel, membership) = db
3562        .create_channel(&request.name, parent_id, session.user_id())
3563        .await?;
3564
3565    let root_id = channel.root_id();
3566    let channel = Channel::from_model(channel);
3567
3568    response.send(proto::CreateChannelResponse {
3569        channel: Some(channel.to_proto()),
3570        parent_id: request.parent_id,
3571    })?;
3572
3573    let mut connection_pool = session.connection_pool().await;
3574    if let Some(membership) = membership {
3575        connection_pool.subscribe_to_channel(
3576            membership.user_id,
3577            membership.channel_id,
3578            membership.role,
3579        );
3580        let update = proto::UpdateUserChannels {
3581            channel_memberships: vec![proto::ChannelMembership {
3582                channel_id: membership.channel_id.to_proto(),
3583                role: membership.role.into(),
3584            }],
3585            ..Default::default()
3586        };
3587        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3588            session.peer.send(connection_id, update.clone())?;
3589        }
3590    }
3591
3592    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3593        if !role.can_see_channel(channel.visibility) {
3594            continue;
3595        }
3596
3597        let update = proto::UpdateChannels {
3598            channels: vec![channel.to_proto()],
3599            ..Default::default()
3600        };
3601        session.peer.send(connection_id, update.clone())?;
3602    }
3603
3604    Ok(())
3605}
3606
3607/// Delete a channel
3608async fn delete_channel(
3609    request: proto::DeleteChannel,
3610    response: Response<proto::DeleteChannel>,
3611    session: UserSession,
3612) -> Result<()> {
3613    let db = session.db().await;
3614
3615    let channel_id = request.channel_id;
3616    let (root_channel, removed_channels) = db
3617        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3618        .await?;
3619    response.send(proto::Ack {})?;
3620
3621    // Notify members of removed channels
3622    let mut update = proto::UpdateChannels::default();
3623    update
3624        .delete_channels
3625        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3626
3627    let connection_pool = session.connection_pool().await;
3628    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3629        session.peer.send(connection_id, update.clone())?;
3630    }
3631
3632    Ok(())
3633}
3634
3635/// Invite someone to join a channel.
3636async fn invite_channel_member(
3637    request: proto::InviteChannelMember,
3638    response: Response<proto::InviteChannelMember>,
3639    session: UserSession,
3640) -> Result<()> {
3641    let db = session.db().await;
3642    let channel_id = ChannelId::from_proto(request.channel_id);
3643    let invitee_id = UserId::from_proto(request.user_id);
3644    let InviteMemberResult {
3645        channel,
3646        notifications,
3647    } = db
3648        .invite_channel_member(
3649            channel_id,
3650            invitee_id,
3651            session.user_id(),
3652            request.role().into(),
3653        )
3654        .await?;
3655
3656    let update = proto::UpdateChannels {
3657        channel_invitations: vec![channel.to_proto()],
3658        ..Default::default()
3659    };
3660
3661    let connection_pool = session.connection_pool().await;
3662    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3663        session.peer.send(connection_id, update.clone())?;
3664    }
3665
3666    send_notifications(&connection_pool, &session.peer, notifications);
3667
3668    response.send(proto::Ack {})?;
3669    Ok(())
3670}
3671
3672/// remove someone from a channel
3673async fn remove_channel_member(
3674    request: proto::RemoveChannelMember,
3675    response: Response<proto::RemoveChannelMember>,
3676    session: UserSession,
3677) -> Result<()> {
3678    let db = session.db().await;
3679    let channel_id = ChannelId::from_proto(request.channel_id);
3680    let member_id = UserId::from_proto(request.user_id);
3681
3682    let RemoveChannelMemberResult {
3683        membership_update,
3684        notification_id,
3685    } = db
3686        .remove_channel_member(channel_id, member_id, session.user_id())
3687        .await?;
3688
3689    let mut connection_pool = session.connection_pool().await;
3690    notify_membership_updated(
3691        &mut connection_pool,
3692        membership_update,
3693        member_id,
3694        &session.peer,
3695    );
3696    for connection_id in connection_pool.user_connection_ids(member_id) {
3697        if let Some(notification_id) = notification_id {
3698            session
3699                .peer
3700                .send(
3701                    connection_id,
3702                    proto::DeleteNotification {
3703                        notification_id: notification_id.to_proto(),
3704                    },
3705                )
3706                .trace_err();
3707        }
3708    }
3709
3710    response.send(proto::Ack {})?;
3711    Ok(())
3712}
3713
3714/// Toggle the channel between public and private.
3715/// Care is taken to maintain the invariant that public channels only descend from public channels,
3716/// (though members-only channels can appear at any point in the hierarchy).
3717async fn set_channel_visibility(
3718    request: proto::SetChannelVisibility,
3719    response: Response<proto::SetChannelVisibility>,
3720    session: UserSession,
3721) -> Result<()> {
3722    let db = session.db().await;
3723    let channel_id = ChannelId::from_proto(request.channel_id);
3724    let visibility = request.visibility().into();
3725
3726    let channel_model = db
3727        .set_channel_visibility(channel_id, visibility, session.user_id())
3728        .await?;
3729    let root_id = channel_model.root_id();
3730    let channel = Channel::from_model(channel_model);
3731
3732    let mut connection_pool = session.connection_pool().await;
3733    for (user_id, role) in connection_pool
3734        .channel_user_ids(root_id)
3735        .collect::<Vec<_>>()
3736        .into_iter()
3737    {
3738        let update = if role.can_see_channel(channel.visibility) {
3739            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3740            proto::UpdateChannels {
3741                channels: vec![channel.to_proto()],
3742                ..Default::default()
3743            }
3744        } else {
3745            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3746            proto::UpdateChannels {
3747                delete_channels: vec![channel.id.to_proto()],
3748                ..Default::default()
3749            }
3750        };
3751
3752        for connection_id in connection_pool.user_connection_ids(user_id) {
3753            session.peer.send(connection_id, update.clone())?;
3754        }
3755    }
3756
3757    response.send(proto::Ack {})?;
3758    Ok(())
3759}
3760
3761/// Alter the role for a user in the channel.
3762async fn set_channel_member_role(
3763    request: proto::SetChannelMemberRole,
3764    response: Response<proto::SetChannelMemberRole>,
3765    session: UserSession,
3766) -> Result<()> {
3767    let db = session.db().await;
3768    let channel_id = ChannelId::from_proto(request.channel_id);
3769    let member_id = UserId::from_proto(request.user_id);
3770    let result = db
3771        .set_channel_member_role(
3772            channel_id,
3773            session.user_id(),
3774            member_id,
3775            request.role().into(),
3776        )
3777        .await?;
3778
3779    match result {
3780        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3781            let mut connection_pool = session.connection_pool().await;
3782            notify_membership_updated(
3783                &mut connection_pool,
3784                membership_update,
3785                member_id,
3786                &session.peer,
3787            )
3788        }
3789        db::SetMemberRoleResult::InviteUpdated(channel) => {
3790            let update = proto::UpdateChannels {
3791                channel_invitations: vec![channel.to_proto()],
3792                ..Default::default()
3793            };
3794
3795            for connection_id in session
3796                .connection_pool()
3797                .await
3798                .user_connection_ids(member_id)
3799            {
3800                session.peer.send(connection_id, update.clone())?;
3801            }
3802        }
3803    }
3804
3805    response.send(proto::Ack {})?;
3806    Ok(())
3807}
3808
3809/// Change the name of a channel
3810async fn rename_channel(
3811    request: proto::RenameChannel,
3812    response: Response<proto::RenameChannel>,
3813    session: UserSession,
3814) -> Result<()> {
3815    let db = session.db().await;
3816    let channel_id = ChannelId::from_proto(request.channel_id);
3817    let channel_model = db
3818        .rename_channel(channel_id, session.user_id(), &request.name)
3819        .await?;
3820    let root_id = channel_model.root_id();
3821    let channel = Channel::from_model(channel_model);
3822
3823    response.send(proto::RenameChannelResponse {
3824        channel: Some(channel.to_proto()),
3825    })?;
3826
3827    let connection_pool = session.connection_pool().await;
3828    let update = proto::UpdateChannels {
3829        channels: vec![channel.to_proto()],
3830        ..Default::default()
3831    };
3832    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3833        if role.can_see_channel(channel.visibility) {
3834            session.peer.send(connection_id, update.clone())?;
3835        }
3836    }
3837
3838    Ok(())
3839}
3840
3841/// Move a channel to a new parent.
3842async fn move_channel(
3843    request: proto::MoveChannel,
3844    response: Response<proto::MoveChannel>,
3845    session: UserSession,
3846) -> Result<()> {
3847    let channel_id = ChannelId::from_proto(request.channel_id);
3848    let to = ChannelId::from_proto(request.to);
3849
3850    let (root_id, channels) = session
3851        .db()
3852        .await
3853        .move_channel(channel_id, to, session.user_id())
3854        .await?;
3855
3856    let connection_pool = session.connection_pool().await;
3857    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3858        let channels = channels
3859            .iter()
3860            .filter_map(|channel| {
3861                if role.can_see_channel(channel.visibility) {
3862                    Some(channel.to_proto())
3863                } else {
3864                    None
3865                }
3866            })
3867            .collect::<Vec<_>>();
3868        if channels.is_empty() {
3869            continue;
3870        }
3871
3872        let update = proto::UpdateChannels {
3873            channels,
3874            ..Default::default()
3875        };
3876
3877        session.peer.send(connection_id, update.clone())?;
3878    }
3879
3880    response.send(Ack {})?;
3881    Ok(())
3882}
3883
3884/// Get the list of channel members
3885async fn get_channel_members(
3886    request: proto::GetChannelMembers,
3887    response: Response<proto::GetChannelMembers>,
3888    session: UserSession,
3889) -> Result<()> {
3890    let db = session.db().await;
3891    let channel_id = ChannelId::from_proto(request.channel_id);
3892    let limit = if request.limit == 0 {
3893        u16::MAX as u64
3894    } else {
3895        request.limit
3896    };
3897    let (members, users) = db
3898        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3899        .await?;
3900    response.send(proto::GetChannelMembersResponse { members, users })?;
3901    Ok(())
3902}
3903
3904/// Accept or decline a channel invitation.
3905async fn respond_to_channel_invite(
3906    request: proto::RespondToChannelInvite,
3907    response: Response<proto::RespondToChannelInvite>,
3908    session: UserSession,
3909) -> Result<()> {
3910    let db = session.db().await;
3911    let channel_id = ChannelId::from_proto(request.channel_id);
3912    let RespondToChannelInvite {
3913        membership_update,
3914        notifications,
3915    } = db
3916        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3917        .await?;
3918
3919    let mut connection_pool = session.connection_pool().await;
3920    if let Some(membership_update) = membership_update {
3921        notify_membership_updated(
3922            &mut connection_pool,
3923            membership_update,
3924            session.user_id(),
3925            &session.peer,
3926        );
3927    } else {
3928        let update = proto::UpdateChannels {
3929            remove_channel_invitations: vec![channel_id.to_proto()],
3930            ..Default::default()
3931        };
3932
3933        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3934            session.peer.send(connection_id, update.clone())?;
3935        }
3936    };
3937
3938    send_notifications(&connection_pool, &session.peer, notifications);
3939
3940    response.send(proto::Ack {})?;
3941
3942    Ok(())
3943}
3944
3945/// Join the channels' room
3946async fn join_channel(
3947    request: proto::JoinChannel,
3948    response: Response<proto::JoinChannel>,
3949    session: UserSession,
3950) -> Result<()> {
3951    let channel_id = ChannelId::from_proto(request.channel_id);
3952    join_channel_internal(channel_id, Box::new(response), session).await
3953}
3954
3955trait JoinChannelInternalResponse {
3956    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3957}
3958impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3959    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3960        Response::<proto::JoinChannel>::send(self, result)
3961    }
3962}
3963impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3964    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3965        Response::<proto::JoinRoom>::send(self, result)
3966    }
3967}
3968
3969async fn join_channel_internal(
3970    channel_id: ChannelId,
3971    response: Box<impl JoinChannelInternalResponse>,
3972    session: UserSession,
3973) -> Result<()> {
3974    let joined_room = {
3975        let mut db = session.db().await;
3976        // If zed quits without leaving the room, and the user re-opens zed before the
3977        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3978        // room they were in.
3979        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3980            tracing::info!(
3981                stale_connection_id = %connection,
3982                "cleaning up stale connection",
3983            );
3984            drop(db);
3985            leave_room_for_session(&session, connection).await?;
3986            db = session.db().await;
3987        }
3988
3989        let (joined_room, membership_updated, role) = db
3990            .join_channel(channel_id, session.user_id(), session.connection_id)
3991            .await?;
3992
3993        let live_kit_connection_info =
3994            session
3995                .app_state
3996                .live_kit_client
3997                .as_ref()
3998                .and_then(|live_kit| {
3999                    let (can_publish, token) = if role == ChannelRole::Guest {
4000                        (
4001                            false,
4002                            live_kit
4003                                .guest_token(
4004                                    &joined_room.room.live_kit_room,
4005                                    &session.user_id().to_string(),
4006                                )
4007                                .trace_err()?,
4008                        )
4009                    } else {
4010                        (
4011                            true,
4012                            live_kit
4013                                .room_token(
4014                                    &joined_room.room.live_kit_room,
4015                                    &session.user_id().to_string(),
4016                                )
4017                                .trace_err()?,
4018                        )
4019                    };
4020
4021                    Some(LiveKitConnectionInfo {
4022                        server_url: live_kit.url().into(),
4023                        token,
4024                        can_publish,
4025                    })
4026                });
4027
4028        response.send(proto::JoinRoomResponse {
4029            room: Some(joined_room.room.clone()),
4030            channel_id: joined_room
4031                .channel
4032                .as_ref()
4033                .map(|channel| channel.id.to_proto()),
4034            live_kit_connection_info,
4035        })?;
4036
4037        let mut connection_pool = session.connection_pool().await;
4038        if let Some(membership_updated) = membership_updated {
4039            notify_membership_updated(
4040                &mut connection_pool,
4041                membership_updated,
4042                session.user_id(),
4043                &session.peer,
4044            );
4045        }
4046
4047        room_updated(&joined_room.room, &session.peer);
4048
4049        joined_room
4050    };
4051
4052    channel_updated(
4053        &joined_room
4054            .channel
4055            .ok_or_else(|| anyhow!("channel not returned"))?,
4056        &joined_room.room,
4057        &session.peer,
4058        &*session.connection_pool().await,
4059    );
4060
4061    update_user_contacts(session.user_id(), &session).await?;
4062    Ok(())
4063}
4064
4065/// Start editing the channel notes
4066async fn join_channel_buffer(
4067    request: proto::JoinChannelBuffer,
4068    response: Response<proto::JoinChannelBuffer>,
4069    session: UserSession,
4070) -> Result<()> {
4071    let db = session.db().await;
4072    let channel_id = ChannelId::from_proto(request.channel_id);
4073
4074    let open_response = db
4075        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4076        .await?;
4077
4078    let collaborators = open_response.collaborators.clone();
4079    response.send(open_response)?;
4080
4081    let update = UpdateChannelBufferCollaborators {
4082        channel_id: channel_id.to_proto(),
4083        collaborators: collaborators.clone(),
4084    };
4085    channel_buffer_updated(
4086        session.connection_id,
4087        collaborators
4088            .iter()
4089            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4090        &update,
4091        &session.peer,
4092    );
4093
4094    Ok(())
4095}
4096
4097/// Edit the channel notes
4098async fn update_channel_buffer(
4099    request: proto::UpdateChannelBuffer,
4100    session: UserSession,
4101) -> Result<()> {
4102    let db = session.db().await;
4103    let channel_id = ChannelId::from_proto(request.channel_id);
4104
4105    let (collaborators, epoch, version) = db
4106        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4107        .await?;
4108
4109    channel_buffer_updated(
4110        session.connection_id,
4111        collaborators.clone(),
4112        &proto::UpdateChannelBuffer {
4113            channel_id: channel_id.to_proto(),
4114            operations: request.operations,
4115        },
4116        &session.peer,
4117    );
4118
4119    let pool = &*session.connection_pool().await;
4120
4121    let non_collaborators =
4122        pool.channel_connection_ids(channel_id)
4123            .filter_map(|(connection_id, _)| {
4124                if collaborators.contains(&connection_id) {
4125                    None
4126                } else {
4127                    Some(connection_id)
4128                }
4129            });
4130
4131    broadcast(None, non_collaborators, |peer_id| {
4132        session.peer.send(
4133            peer_id,
4134            proto::UpdateChannels {
4135                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4136                    channel_id: channel_id.to_proto(),
4137                    epoch: epoch as u64,
4138                    version: version.clone(),
4139                }],
4140                ..Default::default()
4141            },
4142        )
4143    });
4144
4145    Ok(())
4146}
4147
4148/// Rejoin the channel notes after a connection blip
4149async fn rejoin_channel_buffers(
4150    request: proto::RejoinChannelBuffers,
4151    response: Response<proto::RejoinChannelBuffers>,
4152    session: UserSession,
4153) -> Result<()> {
4154    let db = session.db().await;
4155    let buffers = db
4156        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4157        .await?;
4158
4159    for rejoined_buffer in &buffers {
4160        let collaborators_to_notify = rejoined_buffer
4161            .buffer
4162            .collaborators
4163            .iter()
4164            .filter_map(|c| Some(c.peer_id?.into()));
4165        channel_buffer_updated(
4166            session.connection_id,
4167            collaborators_to_notify,
4168            &proto::UpdateChannelBufferCollaborators {
4169                channel_id: rejoined_buffer.buffer.channel_id,
4170                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4171            },
4172            &session.peer,
4173        );
4174    }
4175
4176    response.send(proto::RejoinChannelBuffersResponse {
4177        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4178    })?;
4179
4180    Ok(())
4181}
4182
4183/// Stop editing the channel notes
4184async fn leave_channel_buffer(
4185    request: proto::LeaveChannelBuffer,
4186    response: Response<proto::LeaveChannelBuffer>,
4187    session: UserSession,
4188) -> Result<()> {
4189    let db = session.db().await;
4190    let channel_id = ChannelId::from_proto(request.channel_id);
4191
4192    let left_buffer = db
4193        .leave_channel_buffer(channel_id, session.connection_id)
4194        .await?;
4195
4196    response.send(Ack {})?;
4197
4198    channel_buffer_updated(
4199        session.connection_id,
4200        left_buffer.connections,
4201        &proto::UpdateChannelBufferCollaborators {
4202            channel_id: channel_id.to_proto(),
4203            collaborators: left_buffer.collaborators,
4204        },
4205        &session.peer,
4206    );
4207
4208    Ok(())
4209}
4210
4211fn channel_buffer_updated<T: EnvelopedMessage>(
4212    sender_id: ConnectionId,
4213    collaborators: impl IntoIterator<Item = ConnectionId>,
4214    message: &T,
4215    peer: &Peer,
4216) {
4217    broadcast(Some(sender_id), collaborators, |peer_id| {
4218        peer.send(peer_id, message.clone())
4219    });
4220}
4221
4222fn send_notifications(
4223    connection_pool: &ConnectionPool,
4224    peer: &Peer,
4225    notifications: db::NotificationBatch,
4226) {
4227    for (user_id, notification) in notifications {
4228        for connection_id in connection_pool.user_connection_ids(user_id) {
4229            if let Err(error) = peer.send(
4230                connection_id,
4231                proto::AddNotification {
4232                    notification: Some(notification.clone()),
4233                },
4234            ) {
4235                tracing::error!(
4236                    "failed to send notification to {:?} {}",
4237                    connection_id,
4238                    error
4239                );
4240            }
4241        }
4242    }
4243}
4244
4245/// Send a message to the channel
4246async fn send_channel_message(
4247    request: proto::SendChannelMessage,
4248    response: Response<proto::SendChannelMessage>,
4249    session: UserSession,
4250) -> Result<()> {
4251    // Validate the message body.
4252    let body = request.body.trim().to_string();
4253    if body.len() > MAX_MESSAGE_LEN {
4254        return Err(anyhow!("message is too long"))?;
4255    }
4256    if body.is_empty() {
4257        return Err(anyhow!("message can't be blank"))?;
4258    }
4259
4260    // TODO: adjust mentions if body is trimmed
4261
4262    let timestamp = OffsetDateTime::now_utc();
4263    let nonce = request
4264        .nonce
4265        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4266
4267    let channel_id = ChannelId::from_proto(request.channel_id);
4268    let CreatedChannelMessage {
4269        message_id,
4270        participant_connection_ids,
4271        notifications,
4272    } = session
4273        .db()
4274        .await
4275        .create_channel_message(
4276            channel_id,
4277            session.user_id(),
4278            &body,
4279            &request.mentions,
4280            timestamp,
4281            nonce.clone().into(),
4282            request.reply_to_message_id.map(MessageId::from_proto),
4283        )
4284        .await?;
4285
4286    let message = proto::ChannelMessage {
4287        sender_id: session.user_id().to_proto(),
4288        id: message_id.to_proto(),
4289        body,
4290        mentions: request.mentions,
4291        timestamp: timestamp.unix_timestamp() as u64,
4292        nonce: Some(nonce),
4293        reply_to_message_id: request.reply_to_message_id,
4294        edited_at: None,
4295    };
4296    broadcast(
4297        Some(session.connection_id),
4298        participant_connection_ids.clone(),
4299        |connection| {
4300            session.peer.send(
4301                connection,
4302                proto::ChannelMessageSent {
4303                    channel_id: channel_id.to_proto(),
4304                    message: Some(message.clone()),
4305                },
4306            )
4307        },
4308    );
4309    response.send(proto::SendChannelMessageResponse {
4310        message: Some(message),
4311    })?;
4312
4313    let pool = &*session.connection_pool().await;
4314    let non_participants =
4315        pool.channel_connection_ids(channel_id)
4316            .filter_map(|(connection_id, _)| {
4317                if participant_connection_ids.contains(&connection_id) {
4318                    None
4319                } else {
4320                    Some(connection_id)
4321                }
4322            });
4323    broadcast(None, non_participants, |peer_id| {
4324        session.peer.send(
4325            peer_id,
4326            proto::UpdateChannels {
4327                latest_channel_message_ids: vec![proto::ChannelMessageId {
4328                    channel_id: channel_id.to_proto(),
4329                    message_id: message_id.to_proto(),
4330                }],
4331                ..Default::default()
4332            },
4333        )
4334    });
4335    send_notifications(pool, &session.peer, notifications);
4336
4337    Ok(())
4338}
4339
4340/// Delete a channel message
4341async fn remove_channel_message(
4342    request: proto::RemoveChannelMessage,
4343    response: Response<proto::RemoveChannelMessage>,
4344    session: UserSession,
4345) -> Result<()> {
4346    let channel_id = ChannelId::from_proto(request.channel_id);
4347    let message_id = MessageId::from_proto(request.message_id);
4348    let (connection_ids, existing_notification_ids) = session
4349        .db()
4350        .await
4351        .remove_channel_message(channel_id, message_id, session.user_id())
4352        .await?;
4353
4354    broadcast(
4355        Some(session.connection_id),
4356        connection_ids,
4357        move |connection| {
4358            session.peer.send(connection, request.clone())?;
4359
4360            for notification_id in &existing_notification_ids {
4361                session.peer.send(
4362                    connection,
4363                    proto::DeleteNotification {
4364                        notification_id: (*notification_id).to_proto(),
4365                    },
4366                )?;
4367            }
4368
4369            Ok(())
4370        },
4371    );
4372    response.send(proto::Ack {})?;
4373    Ok(())
4374}
4375
4376async fn update_channel_message(
4377    request: proto::UpdateChannelMessage,
4378    response: Response<proto::UpdateChannelMessage>,
4379    session: UserSession,
4380) -> Result<()> {
4381    let channel_id = ChannelId::from_proto(request.channel_id);
4382    let message_id = MessageId::from_proto(request.message_id);
4383    let updated_at = OffsetDateTime::now_utc();
4384    let UpdatedChannelMessage {
4385        message_id,
4386        participant_connection_ids,
4387        notifications,
4388        reply_to_message_id,
4389        timestamp,
4390        deleted_mention_notification_ids,
4391        updated_mention_notifications,
4392    } = session
4393        .db()
4394        .await
4395        .update_channel_message(
4396            channel_id,
4397            message_id,
4398            session.user_id(),
4399            request.body.as_str(),
4400            &request.mentions,
4401            updated_at,
4402        )
4403        .await?;
4404
4405    let nonce = request
4406        .nonce
4407        .clone()
4408        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4409
4410    let message = proto::ChannelMessage {
4411        sender_id: session.user_id().to_proto(),
4412        id: message_id.to_proto(),
4413        body: request.body.clone(),
4414        mentions: request.mentions.clone(),
4415        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4416        nonce: Some(nonce),
4417        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4418        edited_at: Some(updated_at.unix_timestamp() as u64),
4419    };
4420
4421    response.send(proto::Ack {})?;
4422
4423    let pool = &*session.connection_pool().await;
4424    broadcast(
4425        Some(session.connection_id),
4426        participant_connection_ids,
4427        |connection| {
4428            session.peer.send(
4429                connection,
4430                proto::ChannelMessageUpdate {
4431                    channel_id: channel_id.to_proto(),
4432                    message: Some(message.clone()),
4433                },
4434            )?;
4435
4436            for notification_id in &deleted_mention_notification_ids {
4437                session.peer.send(
4438                    connection,
4439                    proto::DeleteNotification {
4440                        notification_id: (*notification_id).to_proto(),
4441                    },
4442                )?;
4443            }
4444
4445            for notification in &updated_mention_notifications {
4446                session.peer.send(
4447                    connection,
4448                    proto::UpdateNotification {
4449                        notification: Some(notification.clone()),
4450                    },
4451                )?;
4452            }
4453
4454            Ok(())
4455        },
4456    );
4457
4458    send_notifications(pool, &session.peer, notifications);
4459
4460    Ok(())
4461}
4462
4463/// Mark a channel message as read
4464async fn acknowledge_channel_message(
4465    request: proto::AckChannelMessage,
4466    session: UserSession,
4467) -> Result<()> {
4468    let channel_id = ChannelId::from_proto(request.channel_id);
4469    let message_id = MessageId::from_proto(request.message_id);
4470    let notifications = session
4471        .db()
4472        .await
4473        .observe_channel_message(channel_id, session.user_id(), message_id)
4474        .await?;
4475    send_notifications(
4476        &*session.connection_pool().await,
4477        &session.peer,
4478        notifications,
4479    );
4480    Ok(())
4481}
4482
4483/// Mark a buffer version as synced
4484async fn acknowledge_buffer_version(
4485    request: proto::AckBufferOperation,
4486    session: UserSession,
4487) -> Result<()> {
4488    let buffer_id = BufferId::from_proto(request.buffer_id);
4489    session
4490        .db()
4491        .await
4492        .observe_buffer_version(
4493            buffer_id,
4494            session.user_id(),
4495            request.epoch as i32,
4496            &request.version,
4497        )
4498        .await?;
4499    Ok(())
4500}
4501
4502async fn count_language_model_tokens(
4503    request: proto::CountLanguageModelTokens,
4504    response: Response<proto::CountLanguageModelTokens>,
4505    session: Session,
4506    config: &Config,
4507) -> Result<()> {
4508    let Some(session) = session.for_user() else {
4509        return Err(anyhow!("user not found"))?;
4510    };
4511    authorize_access_to_legacy_llm_endpoints(&session).await?;
4512
4513    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4514        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4515        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4516    };
4517
4518    session
4519        .app_state
4520        .rate_limiter
4521        .check(&*rate_limit, session.user_id())
4522        .await?;
4523
4524    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4525        Some(proto::LanguageModelProvider::Google) => {
4526            let api_key = config
4527                .google_ai_api_key
4528                .as_ref()
4529                .context("no Google AI API key configured on the server")?;
4530            google_ai::count_tokens(
4531                session.http_client.as_ref(),
4532                google_ai::API_URL,
4533                api_key,
4534                serde_json::from_str(&request.request)?,
4535                None,
4536            )
4537            .await?
4538        }
4539        _ => return Err(anyhow!("unsupported provider"))?,
4540    };
4541
4542    response.send(proto::CountLanguageModelTokensResponse {
4543        token_count: result.total_tokens as u32,
4544    })?;
4545
4546    Ok(())
4547}
4548
4549struct ZedProCountLanguageModelTokensRateLimit;
4550
4551impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4552    fn capacity(&self) -> usize {
4553        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4554            .ok()
4555            .and_then(|v| v.parse().ok())
4556            .unwrap_or(600) // Picked arbitrarily
4557    }
4558
4559    fn refill_duration(&self) -> chrono::Duration {
4560        chrono::Duration::hours(1)
4561    }
4562
4563    fn db_name(&self) -> &'static str {
4564        "zed-pro:count-language-model-tokens"
4565    }
4566}
4567
4568struct FreeCountLanguageModelTokensRateLimit;
4569
4570impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4571    fn capacity(&self) -> usize {
4572        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4573            .ok()
4574            .and_then(|v| v.parse().ok())
4575            .unwrap_or(600 / 10) // Picked arbitrarily
4576    }
4577
4578    fn refill_duration(&self) -> chrono::Duration {
4579        chrono::Duration::hours(1)
4580    }
4581
4582    fn db_name(&self) -> &'static str {
4583        "free:count-language-model-tokens"
4584    }
4585}
4586
4587struct ZedProComputeEmbeddingsRateLimit;
4588
4589impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4590    fn capacity(&self) -> usize {
4591        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4592            .ok()
4593            .and_then(|v| v.parse().ok())
4594            .unwrap_or(5000) // Picked arbitrarily
4595    }
4596
4597    fn refill_duration(&self) -> chrono::Duration {
4598        chrono::Duration::hours(1)
4599    }
4600
4601    fn db_name(&self) -> &'static str {
4602        "zed-pro:compute-embeddings"
4603    }
4604}
4605
4606struct FreeComputeEmbeddingsRateLimit;
4607
4608impl RateLimit for FreeComputeEmbeddingsRateLimit {
4609    fn capacity(&self) -> usize {
4610        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4611            .ok()
4612            .and_then(|v| v.parse().ok())
4613            .unwrap_or(5000 / 10) // Picked arbitrarily
4614    }
4615
4616    fn refill_duration(&self) -> chrono::Duration {
4617        chrono::Duration::hours(1)
4618    }
4619
4620    fn db_name(&self) -> &'static str {
4621        "free:compute-embeddings"
4622    }
4623}
4624
4625async fn compute_embeddings(
4626    request: proto::ComputeEmbeddings,
4627    response: Response<proto::ComputeEmbeddings>,
4628    session: UserSession,
4629    api_key: Option<Arc<str>>,
4630) -> Result<()> {
4631    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4632    authorize_access_to_legacy_llm_endpoints(&session).await?;
4633
4634    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4635        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4636        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4637    };
4638
4639    session
4640        .app_state
4641        .rate_limiter
4642        .check(&*rate_limit, session.user_id())
4643        .await?;
4644
4645    let embeddings = match request.model.as_str() {
4646        "openai/text-embedding-3-small" => {
4647            open_ai::embed(
4648                session.http_client.as_ref(),
4649                OPEN_AI_API_URL,
4650                &api_key,
4651                OpenAiEmbeddingModel::TextEmbedding3Small,
4652                request.texts.iter().map(|text| text.as_str()),
4653            )
4654            .await?
4655        }
4656        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4657    };
4658
4659    let embeddings = request
4660        .texts
4661        .iter()
4662        .map(|text| {
4663            let mut hasher = sha2::Sha256::new();
4664            hasher.update(text.as_bytes());
4665            let result = hasher.finalize();
4666            result.to_vec()
4667        })
4668        .zip(
4669            embeddings
4670                .data
4671                .into_iter()
4672                .map(|embedding| embedding.embedding),
4673        )
4674        .collect::<HashMap<_, _>>();
4675
4676    let db = session.db().await;
4677    db.save_embeddings(&request.model, &embeddings)
4678        .await
4679        .context("failed to save embeddings")
4680        .trace_err();
4681
4682    response.send(proto::ComputeEmbeddingsResponse {
4683        embeddings: embeddings
4684            .into_iter()
4685            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4686            .collect(),
4687    })?;
4688    Ok(())
4689}
4690
4691async fn get_cached_embeddings(
4692    request: proto::GetCachedEmbeddings,
4693    response: Response<proto::GetCachedEmbeddings>,
4694    session: UserSession,
4695) -> Result<()> {
4696    authorize_access_to_legacy_llm_endpoints(&session).await?;
4697
4698    let db = session.db().await;
4699    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4700
4701    response.send(proto::GetCachedEmbeddingsResponse {
4702        embeddings: embeddings
4703            .into_iter()
4704            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4705            .collect(),
4706    })?;
4707    Ok(())
4708}
4709
4710/// This is leftover from before the LLM service.
4711///
4712/// The endpoints protected by this check will be moved there eventually.
4713async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4714    if session.is_staff() {
4715        Ok(())
4716    } else {
4717        Err(anyhow!("permission denied"))?
4718    }
4719}
4720
4721/// Get a Supermaven API key for the user
4722async fn get_supermaven_api_key(
4723    _request: proto::GetSupermavenApiKey,
4724    response: Response<proto::GetSupermavenApiKey>,
4725    session: UserSession,
4726) -> Result<()> {
4727    let user_id: String = session.user_id().to_string();
4728    if !session.is_staff() {
4729        return Err(anyhow!("supermaven not enabled for this account"))?;
4730    }
4731
4732    let email = session
4733        .email()
4734        .ok_or_else(|| anyhow!("user must have an email"))?;
4735
4736    let supermaven_admin_api = session
4737        .supermaven_client
4738        .as_ref()
4739        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4740
4741    let result = supermaven_admin_api
4742        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4743        .await?;
4744
4745    response.send(proto::GetSupermavenApiKeyResponse {
4746        api_key: result.api_key,
4747    })?;
4748
4749    Ok(())
4750}
4751
4752/// Start receiving chat updates for a channel
4753async fn join_channel_chat(
4754    request: proto::JoinChannelChat,
4755    response: Response<proto::JoinChannelChat>,
4756    session: UserSession,
4757) -> Result<()> {
4758    let channel_id = ChannelId::from_proto(request.channel_id);
4759
4760    let db = session.db().await;
4761    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4762        .await?;
4763    let messages = db
4764        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4765        .await?;
4766    response.send(proto::JoinChannelChatResponse {
4767        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4768        messages,
4769    })?;
4770    Ok(())
4771}
4772
4773/// Stop receiving chat updates for a channel
4774async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4775    let channel_id = ChannelId::from_proto(request.channel_id);
4776    session
4777        .db()
4778        .await
4779        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4780        .await?;
4781    Ok(())
4782}
4783
4784/// Retrieve the chat history for a channel
4785async fn get_channel_messages(
4786    request: proto::GetChannelMessages,
4787    response: Response<proto::GetChannelMessages>,
4788    session: UserSession,
4789) -> Result<()> {
4790    let channel_id = ChannelId::from_proto(request.channel_id);
4791    let messages = session
4792        .db()
4793        .await
4794        .get_channel_messages(
4795            channel_id,
4796            session.user_id(),
4797            MESSAGE_COUNT_PER_PAGE,
4798            Some(MessageId::from_proto(request.before_message_id)),
4799        )
4800        .await?;
4801    response.send(proto::GetChannelMessagesResponse {
4802        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4803        messages,
4804    })?;
4805    Ok(())
4806}
4807
4808/// Retrieve specific chat messages
4809async fn get_channel_messages_by_id(
4810    request: proto::GetChannelMessagesById,
4811    response: Response<proto::GetChannelMessagesById>,
4812    session: UserSession,
4813) -> Result<()> {
4814    let message_ids = request
4815        .message_ids
4816        .iter()
4817        .map(|id| MessageId::from_proto(*id))
4818        .collect::<Vec<_>>();
4819    let messages = session
4820        .db()
4821        .await
4822        .get_channel_messages_by_id(session.user_id(), &message_ids)
4823        .await?;
4824    response.send(proto::GetChannelMessagesResponse {
4825        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4826        messages,
4827    })?;
4828    Ok(())
4829}
4830
4831/// Retrieve the current users notifications
4832async fn get_notifications(
4833    request: proto::GetNotifications,
4834    response: Response<proto::GetNotifications>,
4835    session: UserSession,
4836) -> Result<()> {
4837    let notifications = session
4838        .db()
4839        .await
4840        .get_notifications(
4841            session.user_id(),
4842            NOTIFICATION_COUNT_PER_PAGE,
4843            request.before_id.map(db::NotificationId::from_proto),
4844        )
4845        .await?;
4846    response.send(proto::GetNotificationsResponse {
4847        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4848        notifications,
4849    })?;
4850    Ok(())
4851}
4852
4853/// Mark notifications as read
4854async fn mark_notification_as_read(
4855    request: proto::MarkNotificationRead,
4856    response: Response<proto::MarkNotificationRead>,
4857    session: UserSession,
4858) -> Result<()> {
4859    let database = &session.db().await;
4860    let notifications = database
4861        .mark_notification_as_read_by_id(
4862            session.user_id(),
4863            NotificationId::from_proto(request.notification_id),
4864        )
4865        .await?;
4866    send_notifications(
4867        &*session.connection_pool().await,
4868        &session.peer,
4869        notifications,
4870    );
4871    response.send(proto::Ack {})?;
4872    Ok(())
4873}
4874
4875/// Get the current users information
4876async fn get_private_user_info(
4877    _request: proto::GetPrivateUserInfo,
4878    response: Response<proto::GetPrivateUserInfo>,
4879    session: UserSession,
4880) -> Result<()> {
4881    let db = session.db().await;
4882
4883    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4884    let user = db
4885        .get_user_by_id(session.user_id())
4886        .await?
4887        .ok_or_else(|| anyhow!("user not found"))?;
4888    let flags = db.get_user_flags(session.user_id()).await?;
4889
4890    response.send(proto::GetPrivateUserInfoResponse {
4891        metrics_id,
4892        staff: user.admin,
4893        flags,
4894        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4895    })?;
4896    Ok(())
4897}
4898
4899/// Accept the terms of service (tos) on behalf of the current user
4900async fn accept_terms_of_service(
4901    _request: proto::AcceptTermsOfService,
4902    response: Response<proto::AcceptTermsOfService>,
4903    session: UserSession,
4904) -> Result<()> {
4905    let db = session.db().await;
4906
4907    let accepted_tos_at = Utc::now();
4908    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4909        .await?;
4910
4911    response.send(proto::AcceptTermsOfServiceResponse {
4912        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4913    })?;
4914    Ok(())
4915}
4916
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the time elapsed since the
/// earlier of the user's account creation and GitHub account creation.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4919
4920async fn get_llm_api_token(
4921    _request: proto::GetLlmToken,
4922    response: Response<proto::GetLlmToken>,
4923    session: UserSession,
4924) -> Result<()> {
4925    let db = session.db().await;
4926
4927    let flags = db.get_user_flags(session.user_id()).await?;
4928    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4929    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4930
4931    if !session.is_staff() && !has_language_models_feature_flag {
4932        Err(anyhow!("permission denied"))?
4933    }
4934
4935    let user_id = session.user_id();
4936    let user = db
4937        .get_user_by_id(user_id)
4938        .await?
4939        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4940
4941    if user.accepted_tos_at.is_none() {
4942        Err(anyhow!("terms of service not accepted"))?
4943    }
4944
4945    let mut account_created_at = user.created_at;
4946    if let Some(github_created_at) = user.github_user_created_at {
4947        account_created_at = account_created_at.min(github_created_at);
4948    }
4949    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4950        Err(anyhow!("account too young"))?
4951    }
4952    let token = LlmTokenClaims::create(
4953        user.id,
4954        user.github_login.clone(),
4955        session.is_staff(),
4956        has_llm_closed_beta_feature_flag,
4957        session.current_plan(db).await?,
4958        &session.app_state.config,
4959    )?;
4960    response.send(proto::GetLlmTokenResponse { token })?;
4961    Ok(())
4962}
4963
4964fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4965    let message = match message {
4966        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4967        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4968        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4969        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4970        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4971            code: frame.code.into(),
4972            reason: frame.reason,
4973        })),
4974        // We should never receive a frame while reading the message, according
4975        // to the `tungstenite` maintainers:
4976        //
4977        // > It cannot occur when you read messages from the WebSocket, but it
4978        // > can be used when you want to send the raw frames (e.g. you want to
4979        // > send the frames to the WebSocket without composing the full message first).
4980        // >
4981        // > — https://github.com/snapview/tungstenite-rs/issues/268
4982        TungsteniteMessage::Frame(_) => {
4983            bail!("received an unexpected frame while reading the message")
4984        }
4985    };
4986
4987    Ok(message)
4988}
4989
4990fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4991    match message {
4992        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4993        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4994        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4995        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4996        AxumMessage::Close(frame) => {
4997            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4998                code: frame.code.into(),
4999                reason: frame.reason,
5000            }))
5001        }
5002    }
5003}
5004
5005fn notify_membership_updated(
5006    connection_pool: &mut ConnectionPool,
5007    result: MembershipUpdated,
5008    user_id: UserId,
5009    peer: &Peer,
5010) {
5011    for membership in &result.new_channels.channel_memberships {
5012        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5013    }
5014    for channel_id in &result.removed_channels {
5015        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5016    }
5017
5018    let user_channels_update = proto::UpdateUserChannels {
5019        channel_memberships: result
5020            .new_channels
5021            .channel_memberships
5022            .iter()
5023            .map(|cm| proto::ChannelMembership {
5024                channel_id: cm.channel_id.to_proto(),
5025                role: cm.role.into(),
5026            })
5027            .collect(),
5028        ..Default::default()
5029    };
5030
5031    let mut update = build_channels_update(result.new_channels);
5032    update.delete_channels = result
5033        .removed_channels
5034        .into_iter()
5035        .map(|id| id.to_proto())
5036        .collect();
5037    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5038
5039    for connection_id in connection_pool.user_connection_ids(user_id) {
5040        peer.send(connection_id, user_channels_update.clone())
5041            .trace_err();
5042        peer.send(connection_id, update.clone()).trace_err();
5043    }
5044}
5045
5046fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5047    proto::UpdateUserChannels {
5048        channel_memberships: channels
5049            .channel_memberships
5050            .iter()
5051            .map(|m| proto::ChannelMembership {
5052                channel_id: m.channel_id.to_proto(),
5053                role: m.role.into(),
5054            })
5055            .collect(),
5056        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5057        observed_channel_message_id: channels.observed_channel_messages.clone(),
5058    }
5059}
5060
5061fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5062    let mut update = proto::UpdateChannels::default();
5063
5064    for channel in channels.channels {
5065        update.channels.push(channel.to_proto());
5066    }
5067
5068    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5069    update.latest_channel_message_ids = channels.latest_channel_messages;
5070
5071    for (channel_id, participants) in channels.channel_participants {
5072        update
5073            .channel_participants
5074            .push(proto::ChannelParticipants {
5075                channel_id: channel_id.to_proto(),
5076                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5077            });
5078    }
5079
5080    for channel in channels.invited_channels {
5081        update.channel_invitations.push(channel.to_proto());
5082    }
5083
5084    update.hosted_projects = channels.hosted_projects;
5085    update
5086}
5087
5088fn build_initial_contacts_update(
5089    contacts: Vec<db::Contact>,
5090    pool: &ConnectionPool,
5091) -> proto::UpdateContacts {
5092    let mut update = proto::UpdateContacts::default();
5093
5094    for contact in contacts {
5095        match contact {
5096            db::Contact::Accepted { user_id, busy } => {
5097                update.contacts.push(contact_for_user(user_id, busy, pool));
5098            }
5099            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5100            db::Contact::Incoming { user_id } => {
5101                update
5102                    .incoming_requests
5103                    .push(proto::IncomingContactRequest {
5104                        requester_id: user_id.to_proto(),
5105                    })
5106            }
5107        }
5108    }
5109
5110    update
5111}
5112
5113fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5114    proto::Contact {
5115        user_id: user_id.to_proto(),
5116        online: pool.is_user_online(user_id),
5117        busy,
5118    }
5119}
5120
5121fn room_updated(room: &proto::Room, peer: &Peer) {
5122    broadcast(
5123        None,
5124        room.participants
5125            .iter()
5126            .filter_map(|participant| Some(participant.peer_id?.into())),
5127        |peer_id| {
5128            peer.send(
5129                peer_id,
5130                proto::RoomUpdated {
5131                    room: Some(room.clone()),
5132                },
5133            )
5134        },
5135    );
5136}
5137
5138fn channel_updated(
5139    channel: &db::channel::Model,
5140    room: &proto::Room,
5141    peer: &Peer,
5142    pool: &ConnectionPool,
5143) {
5144    let participants = room
5145        .participants
5146        .iter()
5147        .map(|p| p.user_id)
5148        .collect::<Vec<_>>();
5149
5150    broadcast(
5151        None,
5152        pool.channel_connection_ids(channel.root_id())
5153            .filter_map(|(channel_id, role)| {
5154                role.can_see_channel(channel.visibility)
5155                    .then_some(channel_id)
5156            }),
5157        |peer_id| {
5158            peer.send(
5159                peer_id,
5160                proto::UpdateChannels {
5161                    channel_participants: vec![proto::ChannelParticipants {
5162                        channel_id: channel.id.to_proto(),
5163                        participant_user_ids: participants.clone(),
5164                    }],
5165                    ..Default::default()
5166                },
5167            )
5168        },
5169    );
5170}
5171
5172async fn send_dev_server_projects_update(
5173    user_id: UserId,
5174    mut status: proto::DevServerProjectsUpdate,
5175    session: &Session,
5176) {
5177    let pool = session.connection_pool().await;
5178    for dev_server in &mut status.dev_servers {
5179        dev_server.status =
5180            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5181    }
5182    let connections = pool.user_connection_ids(user_id);
5183    for connection_id in connections {
5184        session.peer.send(connection_id, status.clone()).trace_err();
5185    }
5186}
5187
5188async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5189    let db = session.db().await;
5190
5191    let contacts = db.get_contacts(user_id).await?;
5192    let busy = db.is_user_busy(user_id).await?;
5193
5194    let pool = session.connection_pool().await;
5195    let updated_contact = contact_for_user(user_id, busy, &pool);
5196    for contact in contacts {
5197        if let db::Contact::Accepted {
5198            user_id: contact_user_id,
5199            ..
5200        } = contact
5201        {
5202            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5203                session
5204                    .peer
5205                    .send(
5206                        contact_conn_id,
5207                        proto::UpdateContacts {
5208                            contacts: vec![updated_contact.clone()],
5209                            remove_contacts: Default::default(),
5210                            incoming_requests: Default::default(),
5211                            remove_incoming_requests: Default::default(),
5212                            outgoing_requests: Default::default(),
5213                            remove_outgoing_requests: Default::default(),
5214                        },
5215                    )
5216                    .trace_err();
5217            }
5218        }
5219    }
5220    Ok(())
5221}
5222
5223async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5224    log::info!("lost dev server connection, unsharing projects");
5225    let project_ids = session
5226        .db()
5227        .await
5228        .get_stale_dev_server_projects(session.connection_id)
5229        .await?;
5230
5231    for project_id in project_ids {
5232        // not unshare re-checks the connection ids match, so we get away with no transaction
5233        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5234    }
5235
5236    let user_id = session.dev_server().user_id;
5237    let update = session
5238        .db()
5239        .await
5240        .dev_server_projects_update(user_id)
5241        .await?;
5242
5243    send_dev_server_projects_update(user_id, update, session).await;
5244
5245    Ok(())
5246}
5247
5248async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5249    let mut contacts_to_update = HashSet::default();
5250
5251    let room_id;
5252    let canceled_calls_to_user_ids;
5253    let live_kit_room;
5254    let delete_live_kit_room;
5255    let room;
5256    let channel;
5257
5258    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5259        contacts_to_update.insert(session.user_id());
5260
5261        for project in left_room.left_projects.values() {
5262            project_left(project, session);
5263        }
5264
5265        room_id = RoomId::from_proto(left_room.room.id);
5266        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5267        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5268        delete_live_kit_room = left_room.deleted;
5269        room = mem::take(&mut left_room.room);
5270        channel = mem::take(&mut left_room.channel);
5271
5272        room_updated(&room, &session.peer);
5273    } else {
5274        return Ok(());
5275    }
5276
5277    if let Some(channel) = channel {
5278        channel_updated(
5279            &channel,
5280            &room,
5281            &session.peer,
5282            &*session.connection_pool().await,
5283        );
5284    }
5285
5286    {
5287        let pool = session.connection_pool().await;
5288        for canceled_user_id in canceled_calls_to_user_ids {
5289            for connection_id in pool.user_connection_ids(canceled_user_id) {
5290                session
5291                    .peer
5292                    .send(
5293                        connection_id,
5294                        proto::CallCanceled {
5295                            room_id: room_id.to_proto(),
5296                        },
5297                    )
5298                    .trace_err();
5299            }
5300            contacts_to_update.insert(canceled_user_id);
5301        }
5302    }
5303
5304    for contact_user_id in contacts_to_update {
5305        update_user_contacts(contact_user_id, session).await?;
5306    }
5307
5308    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5309        live_kit
5310            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5311            .await
5312            .trace_err();
5313
5314        if delete_live_kit_room {
5315            live_kit.delete_room(live_kit_room).await.trace_err();
5316        }
5317    }
5318
5319    Ok(())
5320}
5321
5322async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5323    let left_channel_buffers = session
5324        .db()
5325        .await
5326        .leave_channel_buffers(session.connection_id)
5327        .await?;
5328
5329    for left_buffer in left_channel_buffers {
5330        channel_buffer_updated(
5331            session.connection_id,
5332            left_buffer.connections,
5333            &proto::UpdateChannelBufferCollaborators {
5334                channel_id: left_buffer.channel_id.to_proto(),
5335                collaborators: left_buffer.collaborators,
5336            },
5337            &session.peer,
5338        );
5339    }
5340
5341    Ok(())
5342}
5343
5344fn project_left(project: &db::LeftProject, session: &UserSession) {
5345    for connection_id in &project.connection_ids {
5346        if project.should_unshare {
5347            session
5348                .peer
5349                .send(
5350                    *connection_id,
5351                    proto::UnshareProject {
5352                        project_id: project.id.to_proto(),
5353                    },
5354                )
5355                .trace_err();
5356        } else {
5357            session
5358                .peer
5359                .send(
5360                    *connection_id,
5361                    proto::RemoveProjectCollaborator {
5362                        project_id: project.id.to_proto(),
5363                        peer_id: Some(session.connection_id.into()),
5364                    },
5365                )
5366                .trace_err();
5367        }
5368    }
5369}
5370
/// Extension trait for values that can fail, offering a way to log the
/// failure instead of propagating it.
pub trait ResultExt {
    /// The success value handed back by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self` and converts it into an `Option` of the success value.
    ///
    /// The implementation for `Result` in this file returns `Some(value)` for
    /// `Ok(value)`, and for `Err(error)` logs the error with `tracing::error!`
    /// and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5376
5377impl<T, E> ResultExt for Result<T, E>
5378where
5379    E: std::fmt::Debug,
5380{
5381    type Ok = T;
5382
5383    #[track_caller]
5384    fn trace_err(self) -> Option<T> {
5385        match self {
5386            Ok(value) => Some(value),
5387            Err(error) => {
5388                tracing::error!("{:?}", error);
5389                None
5390            }
5391        }
5392    }
5393}